From e90e0bdfa2ef8a3b1d0579759247d1516f093821 Mon Sep 17 00:00:00 2001
From: dengkaipeng
Date: Sat, 2 Mar 2019 09:01:44 +0000
Subject: [PATCH] fix for gpu grad. test=develop

---
 paddle/fluid/operators/kldiv_loss_op.cc            |  2 +-
 paddle/fluid/operators/kldiv_loss_op.h             | 20 +++++++++++++++----
 .../tests/unittests/test_kldiv_loss_op.py          | 13 ++++++------
 3 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc
index f1b35351274..a65bb3bade3 100644
--- a/paddle/fluid/operators/kldiv_loss_op.cc
+++ b/paddle/fluid/operators/kldiv_loss_op.cc
@@ -33,7 +33,7 @@ class KLDivLossOp : public framework::OperatorWithKernel {
     auto dim_target = ctx->GetInputDim("Target");
     PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
                       "Input(X) rank and Input(Target) rank should be same.");
-    for (size_t i = 0; i < dim_x.size(); i++) {
+    for (int i = 0; i < dim_x.size(); i++) {
       PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
                         "Input(X) and Input(Target) should in same shape.");
     }
diff --git a/paddle/fluid/operators/kldiv_loss_op.h b/paddle/fluid/operators/kldiv_loss_op.h
index fa53753d0ed..f262cfbb5fb 100644
--- a/paddle/fluid/operators/kldiv_loss_op.h
+++ b/paddle/fluid/operators/kldiv_loss_op.h
@@ -30,7 +30,7 @@ struct KLDivLossForward {
   HOSTDEVICE KLDivLossForward() {}
 
   HOSTDEVICE T operator()(const T& target, const T& input) const {
-    if (target < 0) {
+    if (target <= 0) {
       return 0;
     } else {
       return target * (std::log(target) - input);
@@ -38,6 +38,19 @@ struct KLDivLossForward {
   }
 };
 
+template <typename T>
+struct KLDivLossBackward {
+  HOSTDEVICE KLDivLossBackward() {}
+
+  HOSTDEVICE T operator()(const T& target, const T& grad) const {
+    if (target <= 0) {
+      return 0;
+    } else {
+      return static_cast<T>(-1.) * grad;
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class KLDivLossKernel : public framework::OpKernel<T> {
  public:
@@ -88,11 +101,10 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
     auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
     auto loss_grad_t = EigenVector<T>::Flatten(*loss_grad);
 
-    auto target_mask = (target_t > target_t.constant(0)).template cast<T>();
 
     auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
-    input_grad_t.device(place) =
-        target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;
+    auto grad_t = target_t * loss_grad_expand;
+    input_grad_t.device(place) = target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
 
     if ("mean" == reduction) {
       input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
index b1d4e7f6ed5..d0212d177e6 100644
--- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
@@ -6,8 +6,7 @@
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
+# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
@@ -21,7 +20,7 @@ from op_test import OpTest
 
 def kldiv_loss(x, target, reduction):
     output = target * (np.log(target) - x)
-    loss = np.where(target > 0, output, np.zeros_like(x))
+    loss = np.where(target >= 0, output, np.zeros_like(x))
 
     if reduction == "batchmean":
         return loss.sum() / x.shape[0]
@@ -57,14 +56,14 @@ class TestKLDivLossOp(OpTest):
             ['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.06)
 
     def initTestCase(self):
-        self.x_shape = (3, 7, 7)
-        self.reduction = 'none'
+        self.x_shape = (2, 5, 5)
+        self.reduction = 'batchmean'
 
 
 class TestKLDivLossOp2(TestKLDivLossOp):
     def initTestCase(self):
-        self.x_shape = (2, 3, 5, 5)
-        self.reduction = 'batchmean'
+        self.x_shape = (3, 2, 7, 7)
+        self.reduction = 'none'
 
 
 class TestKLDivLossOp3(TestKLDivLossOp):
-- 
GitLab
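
For readers tracing the change: below is a minimal, standalone NumPy sketch of the semantics the patched kernels implement. It is not part of the patch, and the helper names kldiv_loss_forward and kldiv_loss_grad are illustrative, not Paddle APIs. The point it demonstrates is the contract the new KLDivLossBackward functor encodes: both the elementwise loss and its gradient with respect to the input are forced to zero wherever target <= 0, and the 'mean' reduction divides the gradient by the element count, as the grad kernel does. A central finite-difference check confirms the analytic gradient.

import numpy as np


def kldiv_loss_forward(x, target, reduction='mean'):
    # Mirrors KLDivLossForward: target * (log(target) - x), forced to 0
    # wherever target <= 0. The inner where() keeps log() away from
    # non-positive values so no NaN/Inf leaks into the masked-out branch.
    mask = target > 0
    safe_target = np.where(mask, target, 1.0)
    loss = np.where(mask, target * (np.log(safe_target) - x), 0.0)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'batchmean':
        return loss.sum() / x.shape[0]
    if reduction == 'sum':
        return loss.sum()
    return loss  # reduction == 'none'


def kldiv_loss_grad(x, target, reduction='mean'):
    # Mirrors KLDivLossBackward: d/dx [target * (log(target) - x)] = -target,
    # zeroed wherever target <= 0, then scaled for the reduction the way the
    # grad kernel scales it (divide by numel for 'mean').
    grad = np.where(target > 0, -target, 0.0)
    if reduction == 'mean':
        grad = grad / x.size
    elif reduction == 'batchmean':
        grad = grad / x.shape[0]
    return grad


if __name__ == '__main__':
    rng = np.random.RandomState(0)
    x = rng.uniform(-10, 10, (2, 3, 3))
    target = rng.uniform(-10, 10, (2, 3, 3))  # some entries are <= 0

    # Central finite differences should agree with the analytic gradient.
    analytic = kldiv_loss_grad(x, target)
    numeric = np.zeros_like(x)
    eps = 1e-6
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        i = it.multi_index
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        numeric[i] = (kldiv_loss_forward(xp, target) -
                      kldiv_loss_forward(xm, target)) / (2 * eps)
        it.iternext()
    print('max |analytic - numeric|:', np.abs(analytic - numeric).max())

The sketch also makes the design choice behind the fix visible: routing the masked case through a single elementwise functor (binaryExpr with KLDivLossBackward), rather than multiplying by a separately materialized mask tensor as the removed Eigen expression did, keeps the zero branch explicit in one pass over the data, which is exactly what the np.where calls above express.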