提交 6ca37448 编写于 作者: Y Yu Yang

Refine prelu_op

上级 ae9378f6
...@@ -32,7 +32,7 @@ class PReluKernel : public framework::OpKernel<T> { ...@@ -32,7 +32,7 @@ class PReluKernel : public framework::OpKernel<T> {
T* o_ptr = out->mutable_data<T>(context.GetPlace()); T* o_ptr = out->mutable_data<T>(context.GetPlace());
const T* alpha_ptr = alpha->data<T>(); const T* alpha_ptr = alpha->data<T>();
std::string mode = context.Attr<std::string>("mode"); auto& mode = context.Attr<std::string>("mode");
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims(); auto dim = x->dims();
...@@ -99,6 +99,8 @@ class PReluGradKernel : public framework::OpKernel<T> { ...@@ -99,6 +99,8 @@ class PReluGradKernel : public framework::OpKernel<T> {
index = 0; index = 0;
if (dalpha) { if (dalpha) {
T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace()); T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace());
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
if (mode == "channel") { if (mode == "channel") {
for (i = 0; i < numel; i++) { for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]); temp = numel / (dim[0] * dim[1]);
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <Python.h> #include <Python.h>
#include <cmake-build-release/third_party/pybind/src/extern_pybind/include/pybind11/common.h>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -22,6 +21,7 @@ limitations under the License. */ ...@@ -22,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "pybind11/common.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册