diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h
index f88ce94dc865f4ca8ddb9e20c22136ee384bdfe2..ece2a836a65e6508580bc32b84f7833388ce55f3 100644
--- a/paddle/operators/prelu_op.h
+++ b/paddle/operators/prelu_op.h
@@ -26,17 +26,17 @@ using platform::Transform;
 template <typename T>
 class PReluFunctor {
  public:
-  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluFunctor(const T* alpha) : alpha_(alpha) {}
 
   HOSTDEVICE T operator()(const T& x) const {
     if (x > 0)
       return x;
     else
-      return x * alpha_;
+      return x * (*alpha_);
   }
 
  private:
-  T alpha_;
+  const T* alpha_;
 };
 
 template <typename Place, typename T>
@@ -50,30 +50,29 @@ class PReluKernel : public framework::OpKernel {
     const T* x_ptr = x->data<T>();
     T* o_ptr = out->mutable_data<T>(context.GetPlace());
 
-    auto alpha_val = alpha->data<T>()[0];
-    // auto alpha = static_cast<T>(context.Attr<float>("alpha"));
+    auto* alpha_ptr = alpha->data<T>();
 
     int numel = x->numel();
 
     auto place = context.GetPlace();
-    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
+    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr));
   }
 };
 
 template <typename T>
 class PReluGradFunctor {
  public:
-  explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {}
 
   HOSTDEVICE T operator()(const T& out, const T& dout) const {
     if (out > 0)
       return dout;
     else
-      return dout * alpha_;
+      return dout * (*alpha_);
   }
 
  private:
-  T alpha_;
+  const T* alpha_;
 };
 
 template <typename Place, typename T>
@@ -85,7 +84,7 @@ class PReluGradKernel : public framework::OpKernel {
 
     auto* out = context.Input<Tensor>("Out");
     auto* alpha = context.Input<Tensor>("Alpha");
-    auto alpha_val = alpha->data<T>()[0];
+    auto* alpha_ptr = alpha->data<T>();
 
     T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
     const T* dout_ptr = dout->data<T>();
@@ -94,7 +93,9 @@ class PReluGradKernel : public framework::OpKernel {
 
     auto place = context.GetPlace();
     Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
-              PReluGradFunctor<T>(alpha_val));
+              PReluGradFunctor<T>(alpha_ptr));
+
+    // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py
index cbf2e6b2a88ce15ecb513dbf1f7973ce12930235..b74812e96917a28a5aa5d7dd3113e3e958c92cdc 100644
--- a/python/paddle/v2/framework/tests/test_prelu_op.py
+++ b/python/paddle/v2/framework/tests/test_prelu_op.py
@@ -3,13 +3,15 @@ import numpy as np
 from op_test import OpTest
 
 
-class PreluTest(OpTest):
+class PReluTest(OpTest):
     def setUp(self):
         self.op_type = "prelu"
-        self.inputs = {'X': np.random.normal(size=(10, 10)).astype("float32")}
-        self.attrs = {'alpha': 0.1}
+        x_np = np.random.normal(size=(10, 10)).astype("float32")
+        alpha_np = np.array([.1])
+        self.inputs = {'X': x_np, 'Alpha': alpha_np}
         out_np = np.maximum(self.inputs['X'], 0.)
-        out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.attrs['alpha']
+        out_np = out_np + np.minimum(self.inputs['X'],
+                                     0.) * self.inputs['Alpha']
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
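
Reviewer note (not part of the patch): the change above stops copying alpha by value and instead hands the functor a pointer into the Alpha tensor's buffer, presumably so the value can be dereferenced on whatever device the data lives on rather than read once on the host. Below is a minimal host-only sketch of that pointer-based functor pattern using std::transform; PReluFunctorSketch, main(), and the sample values are illustrative and are not part of the Paddle code.

// Host-only illustration of the pointer-based PReLU functor pattern.
// Compile with any C++11 compiler; names here are hypothetical.
#include <algorithm>
#include <cstdio>
#include <vector>

template <typename T>
class PReluFunctorSketch {
 public:
  explicit PReluFunctorSketch(const T* alpha) : alpha_(alpha) {}

  // PReLU: identity for positive inputs, scaled by *alpha_ otherwise.
  T operator()(const T& x) const { return x > 0 ? x : x * (*alpha_); }

 private:
  const T* alpha_;  // points into the "Alpha" buffer, not a copied value
};

int main() {
  std::vector<float> x = {-2.f, -0.5f, 0.f, 1.f, 3.f};
  std::vector<float> out(x.size());
  std::vector<float> alpha = {0.1f};  // single-element Alpha input

  std::transform(x.begin(), x.end(), out.begin(),
                 PReluFunctorSketch<float>(alpha.data()));

  for (float v : out) std::printf("%g ", v);  // prints: -0.2 -0.05 0 1 3
  std::printf("\n");
  return 0;
}

In the kernel itself the same idea runs through platform::Transform, which as far as I can tell dispatches to a thrust-based transform on GPU builds, so the alpha pointer has to remain valid in device memory for the duration of the call.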