Commit ebcb7a7a authored by D dengkaipeng

fix grad check. test=develop

Parent 3e3a983a
@@ -81,7 +81,7 @@ class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker {
         "The reduction type to apply to the output, available types "
         "are 'none' | 'batchmean' | 'mean' | 'sum', 'none' for no "
         "reduction, 'batchmean' for the sum of output divided by "
-        "batch size, 'mean' for the average value of all output, "
+        "batchmean size, 'mean' for the average value of all output, "
         "'sum' for the sum of the output.")
         .SetDefault("mean");
......
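As context for the reduction attribute documented in the hunk above, here is a minimal NumPy sketch of what the four modes compute; the function and argument names are illustrative only, not part of the operator:

    import numpy as np

    def reduce_kldiv_output(output, batch_size, reduction):
        # 'none': return the elementwise loss unreduced.
        if reduction == 'none':
            return output
        # 'batchmean': total loss divided by the batch size.
        if reduction == 'batchmean':
            return output.sum() / batch_size
        # 'mean': average over every element of the output.
        if reduction == 'mean':
            return output.mean()
        # 'sum': total loss.
        if reduction == 'sum':
            return output.sum()
        raise ValueError('unsupported reduction: %s' % reduction)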
@@ -13,9 +13,10 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    sum, ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, float>,
+    kldiv_loss,
+    ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, float>,
     ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
-    sum_grad,
+    kldiv_loss_grad,
     ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, double>);
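The fix in this file is purely a registration rename: REGISTER_OP_CUDA_KERNEL binds kernels to an operator type name, and these kernels had been registered under sum / sum_grad (presumably a copy-paste slip), so the kldiv_loss op had no CUDA kernels registered under its own name. No kernel code changes here.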
@@ -54,13 +54,12 @@ class KLDivLossKernel : public framework::OpKernel<T> {
     auto input_t = EigenVector<T>::Flatten(*input);
     auto target_t = EigenVector<T>::Flatten(*target);
     auto loss_t = EigenVector<T>::Flatten(*loss);
-    // auto target_mask = (target_t > target_t.constant(0)).template cast<T>();
-    // auto output = (target_t * (target_t.log() - input_t)) * target_mask;
     auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
     if ("none" == reduction) {
       loss_t.device(place) = output;
     } else if ("batchmean" == reduction) {
-      loss_t.device(place) = output.sum() / static_cast<T>(n);
+      auto output_sum = output.sum().eval();
+      loss_t.device(place) = output_sum / output_sum.constant(n);
     } else if ("mean" == reduction) {
       loss_t.device(place) = output.mean();
     } else if ("sum" == reduction) {
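KLDivLossForward, applied through binaryExpr above, computes the pointwise term y * (log(y) - x) with non-positive targets masked to zero, which is exactly what the deleted comment lines spelled out. A minimal NumPy sketch of that forward term (the epsilon floor is only there to keep the log defined on masked entries):

    import numpy as np

    def kldiv_forward(x, y):
        # Pointwise KL term y * (log(y) - x); entries with y <= 0
        # contribute zero, mirroring the target_mask in the old code.
        safe_y = np.maximum(y, 1e-12)
        return np.where(y > 0, y * (np.log(safe_y) - x), 0.0)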
@@ -74,19 +73,17 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-    auto* input = ctx.Input<Tensor>("X");
     auto* target = ctx.Input<Tensor>("Target");
     auto reduction = ctx.Attr<std::string>("reduction");
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
 
-    const int n = input->dims()[0];
-    const int numel = input->numel();
+    const int n = input_grad->dims()[0];
+    const int numel = input_grad->numel();
     const int expand = numel / loss_grad->numel();
 
     input_grad->mutable_data<T>(ctx.GetPlace());
 
-    auto input_t = EigenVector<T>::Flatten(*input);
     auto target_t = EigenVector<T>::Flatten(*target);
     auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
@@ -96,14 +93,6 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
     auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
     input_grad_t.device(place) =
         target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;
-    // if (reduction == "none") {
-    //   input_grad_t.device(place) =
-    //       target_t * loss_grad_t * target_t.constant(-1.0);
-    // } else {
-    //   auto loss_grad_expand = loss_grad_t.broadcast(Array1(numel));
-    //   input_grad_t.device(place) =
-    //       target_t * loss_grad_expand * target_t.constant(-1.0);
-    // }
     if ("mean" == reduction) {
       input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
......
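Because d/dx of y * (log(y) - x) is -y, the backward pass multiplies -target by the upstream loss gradient (broadcast from the reduced loss back to the input shape via expand) under the same target mask, then divides by numel when the reduction was 'mean'. A minimal NumPy sketch covering just the paths visible in this hunk:

    import numpy as np

    def kldiv_backward(target, loss_grad, reduction):
        # Input gradient: -y times the broadcast upstream gradient,
        # zeroed where the target is non-positive.
        grad = np.where(target > 0,
                        -target * np.broadcast_to(loss_grad, target.shape),
                        0.0)
        if reduction == 'mean':
            # Mirrors the final division by numel in the kernel.
            grad = grad / target.size
        return grad

Reading n and numel from input_grad rather than input also lets the grad kernel drop its dependency on the forward input X, which the gradient of this loss never needs.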
@@ -47,36 +47,37 @@ class TestKLDivLossOp(OpTest):
             'Target': target,
         }
         loss = kldiv_loss(x, target, self.reduction)
-        self.outputs = {'Loss': loss}
+        self.outputs = {'Loss': loss.astype('float32')}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
         self.check_grad(
-            ['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.1)
+            ['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.06)
 
     def initTestCase(self):
         self.x_shape = (3, 7, 7)
         self.reduction = 'none'
 
 
+class TestKLDivLossOp2(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (2, 3, 5, 5)
+        self.reduction = 'batchmean'
-# class TestKLDivLossOp2(TestKLDivLossOp):
-#     def initTestCase(self):
-#         self.x_shape = (3, 7, 7)
-#         self.reduction = 'batchmean'
-#
-#
-# class TestKLDivLossOp3(TestKLDivLossOp):
-#     def initTestCase(self):
-#         self.x_shape = (2, 3, 5, 7, 9)
-#         self.reduction = 'mean'
-#
-#
-# class TestKLDivLossOp4(TestKLDivLossOp):
-#     def initTestCase(self):
-#         self.x_shape = (5, 7)
-#         self.reduction = 'sum'
+
+
+class TestKLDivLossOp3(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (2, 3, 5, 7, 9)
+        self.reduction = 'mean'
+
+
+class TestKLDivLossOp4(TestKLDivLossOp):
+    def initTestCase(self):
+        self.x_shape = (5, 7)
+        self.reduction = 'sum'
 
 if __name__ == "__main__":
     unittest.main()
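The test case calls a module-level helper kldiv_loss(x, target, reduction) defined above the quoted hunk. A plausible reconstruction is sketched below; the actual helper in test_kldiv_loss_op.py may differ in details such as the masking comparison:

    import numpy as np

    def kldiv_loss(x, target, reduction):
        # Reference forward pass: elementwise KL term with
        # non-positive targets contributing zero loss.
        output = target * (np.log(np.maximum(target, 1e-12)) - x)
        loss = np.where(target > 0, output, np.zeros_like(x))
        if reduction == 'batchmean':
            return loss.sum() / x.shape[0]
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss  # 'none'

With the gradient kernel corrected, the previously commented-out batchmean/mean/sum cases are re-enabled, and the gradient-check tolerance can be tightened from 0.1 to 0.06.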