提交 e90e0bdf 编写于 作者: D dengkaipeng

fix for gpu grad. test=develop

上级 ebcb7a7a
......@@ -33,7 +33,7 @@ class KLDivLossOp : public framework::OperatorWithKernel {
auto dim_target = ctx->GetInputDim("Target");
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
"Input(X) rank and Input(Target) rank should be same.");
for (size_t i = 0; i < dim_x.size(); i++) {
for (int i = 0; i < dim_x.size(); i++) {
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
"Input(X) and Input(Target) should in same shape.");
}
......
......@@ -30,7 +30,7 @@ struct KLDivLossForward {
HOSTDEVICE KLDivLossForward() {}
HOSTDEVICE T operator()(const T& target, const T& input) const {
if (target < 0) {
if (target <= 0) {
return 0;
} else {
return target * (std::log(target) - input);
......@@ -38,6 +38,19 @@ struct KLDivLossForward {
}
};
template <typename T>
struct KLDivLossBackward {
  HOSTDEVICE KLDivLossBackward() {}
  // Element-wise gradient w.r.t. the input of KL-divergence loss.
  // The caller passes grad = target * upstream_loss_grad, so the input
  // gradient is simply its negation; positions with non-positive target
  // contribute nothing (matching the forward pass, which emits 0 there).
  HOSTDEVICE T operator()(const T& target, const T& grad) const {
    return target <= 0 ? static_cast<T>(0) : static_cast<T>(-1.) * grad;
  }
};
template <typename DeviceContext, typename T>
class KLDivLossKernel : public framework::OpKernel<T> {
public:
......@@ -88,11 +101,10 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
auto loss_grad_t = EigenVector<T>::Flatten(*loss_grad);
auto target_mask = (target_t > target_t.constant(0)).template cast<T>();
auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
input_grad_t.device(place) =
target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;
auto grad_t = target_t * loss_grad_expand;
input_grad_t.device(place) = target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
if ("mean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
......
......@@ -6,8 +6,7 @@
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -21,7 +20,7 @@ from op_test import OpTest
def kldiv_loss(x, target, reduction):
output = target * (np.log(target) - x)
loss = np.where(target > 0, output, np.zeros_like(x))
loss = np.where(target >= 0, output, np.zeros_like(x))
if reduction == "batchmean":
return loss.sum() / x.shape[0]
......@@ -57,14 +56,14 @@ class TestKLDivLossOp(OpTest):
['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.06)
def initTestCase(self):
self.x_shape = (3, 7, 7)
self.reduction = 'none'
self.x_shape = (2, 5, 5)
self.reduction = 'batchmean'
class TestKLDivLossOp2(TestKLDivLossOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 5)
self.reduction = 'batchmean'
self.x_shape = (3, 2, 7, 7)
self.reduction = 'none'
class TestKLDivLossOp3(TestKLDivLossOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册