From 6ca37448acc17719f633af515f553a475c0842db Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 30 Sep 2018 12:20:12 +0800 Subject: [PATCH] Refine prelu_op --- paddle/fluid/operators/prelu_op.h | 4 +++- paddle/fluid/pybind/tensor_py.h | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index 12f1525594..594f1cb3ab 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -32,7 +32,7 @@ class PReluKernel : public framework::OpKernel { T* o_ptr = out->mutable_data(context.GetPlace()); const T* alpha_ptr = alpha->data(); - std::string mode = context.Attr("mode"); + auto& mode = context.Attr("mode"); int numel = x->numel(); auto dim = x->dims(); @@ -99,6 +99,8 @@ class PReluGradKernel : public framework::OpKernel { index = 0; if (dalpha) { T* dalpha_ptr = dalpha->mutable_data(context.GetPlace()); + memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel()); + if (mode == "channel") { for (i = 0; i < numel; i++) { temp = numel / (dim[0] * dim[1]); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 76ff1acacb..0e5fd97951 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once #include -#include #include #include #include @@ -22,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" +#include "pybind11/common.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" -- GitLab