提交 6ca37448 编写于 作者: Y Yu Yang

Refine prelu_op

上级 ae9378f6
......@@ -32,7 +32,7 @@ class PReluKernel : public framework::OpKernel<T> {
T* o_ptr = out->mutable_data<T>(context.GetPlace());
const T* alpha_ptr = alpha->data<T>();
std::string mode = context.Attr<std::string>("mode");
auto& mode = context.Attr<std::string>("mode");
int numel = x->numel();
auto dim = x->dims();
......@@ -99,6 +99,8 @@ class PReluGradKernel : public framework::OpKernel<T> {
index = 0;
if (dalpha) {
T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace());
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
if (mode == "channel") {
for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include <Python.h>
#include <cmake-build-release/third_party/pybind/src/extern_pybind/include/pybind11/common.h>
#include <string>
#include <tuple>
#include <vector>
......@@ -22,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "pybind11/common.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册