Commit 3c3a6d90 authored by zchen0211

prelu finalize

Parent 1b797468
@@ -26,17 +26,17 @@ using platform::Transform;

 template <typename T>
 class PReluFunctor {
  public:
-  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluFunctor(const T* alpha) : alpha_(alpha) {}

   HOSTDEVICE T operator()(const T& x) const {
     if (x > 0)
       return x;
     else
-      return x * alpha_;
+      return x * (*alpha_);
   }

  private:
-  T alpha_;
+  const T* alpha_;
 };

 template <typename Place, typename T>
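The forward functor now stores a const T* instead of a T value, so the kernel can hand it the raw data pointer of the one-element Alpha tensor (which may live in device memory) rather than copying alpha out on the host. The elementwise rule itself is unchanged; a minimal numpy sketch of it, for illustration only:

import numpy as np

def prelu_forward(x, alpha):
    # Pass positives through; scale negatives (and zero) by alpha.
    return np.where(x > 0, x, x * alpha)

x = np.array([-2.0, -0.5, 0.0, 3.0])
print(prelu_forward(x, 0.1))  # [-0.2  -0.05  0.    0.3 ]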
@@ -50,30 +50,29 @@ class PReluKernel : public framework::OpKernel {
     const T* x_ptr = x->data<T>();
     T* o_ptr = out->mutable_data<T>(context.GetPlace());

-    auto alpha_val = alpha->data<T>()[0];
-    // auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+    auto* alpha_ptr = alpha->data<T>();

     int numel = x->numel();

     auto place = context.GetPlace();
-    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
+    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr));
   }
 };

 template <typename T>
 class PReluGradFunctor {
  public:
-  explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {}

   HOSTDEVICE T operator()(const T& out, const T& dout) const {
     if (out > 0)
       return dout;
     else
-      return dout * alpha_;
+      return dout * (*alpha_);
   }

  private:
-  T alpha_;
+  const T* alpha_;
 };

 template <typename Place, typename T>
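PReluGradFunctor computes the input gradient from Out rather than X: with alpha > 0, out and x have the same sign, so out > 0 identifies the pass-through branch, where d(out)/d(x) = 1; elsewhere it is alpha. A numpy sketch of the same rule, under that assumption:

import numpy as np

def prelu_grad_x(out, dout, alpha):
    # Assumes alpha > 0, so (out > 0) iff (x > 0) and the forward
    # branch can be recovered from the saved output alone.
    return np.where(out > 0, dout, dout * alpha)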
@@ -85,7 +84,7 @@ class PReluGradKernel : public framework::OpKernel {
     auto* out = context.Input<Tensor>("Out");
     auto* alpha = context.Input<Tensor>("Alpha");
-    auto alpha_val = alpha->data<T>()[0];
+    auto* alpha_ptr = alpha->data<T>();

     T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
     const T* dout_ptr = dout->data<T>();
@@ -94,7 +93,9 @@ class PReluGradKernel : public framework::OpKernel {
     auto place = context.GetPlace();

     Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
-              PReluGradFunctor<T>(alpha_val));
+              PReluGradFunctor<T>(alpha_ptr));
+
+    // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
   }
 };
......
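The TODO above defers the Alpha gradient. For reference, the standard PReLU alpha gradient accumulates dout over the negative branch (d(out)/d(alpha) = x when x <= 0, else 0); a sketch of what a later dalpha kernel would compute, assuming a single shared alpha as in this op:

import numpy as np

def prelu_grad_alpha(x, dout):
    # One shared alpha accumulates contributions from every element
    # on the negative branch: sum(dout * x) over x <= 0.
    return np.sum(dout * np.minimum(x, 0.0))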
@@ -3,13 +3,15 @@ import numpy as np
 from op_test import OpTest

-class PreluTest(OpTest):
+class PReluTest(OpTest):
     def setUp(self):
         self.op_type = "prelu"
-        self.inputs = {'X': np.random.normal(size=(10, 10)).astype("float32")}
-        self.attrs = {'alpha': 0.1}
+        x_np = np.random.normal(size=(10, 10)).astype("float32")
+        alpha_np = np.array([.1])
+        self.inputs = {'X': x_np, 'Alpha': alpha_np}
         out_np = np.maximum(self.inputs['X'], 0.)
-        out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.attrs['alpha']
+        out_np = out_np + np.minimum(self.inputs['X'],
+                                     0.) * self.inputs['Alpha']
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
......
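The test now passes Alpha as a one-element input tensor instead of an attribute, mirroring the kernel change; numpy broadcasts the shape-(1,) array across the (10, 10) X. OpTest's gradient check then compares the analytic kernel gradient against numeric differences; a rough standalone sketch of that idea (not the OpTest implementation):

import numpy as np

def numeric_grad(f, x, dout, eps=1e-4):
    # Central-difference estimate of d(sum(f(x) * dout)) / dx.
    grad = np.zeros_like(x)
    for i in np.ndindex(x.shape):
        orig = x[i]
        x[i] = orig + eps
        hi = np.sum(f(x) * dout)
        x[i] = orig - eps
        lo = np.sum(f(x) * dout)
        x[i] = orig
        grad[i] = (hi - lo) / (2 * eps)
    return grad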