提交 3c3a6d90 编写于 作者: Z zchen0211

prelu finalize

上级 1b797468
......@@ -26,17 +26,17 @@ using platform::Transform;
template <typename T>
class PReluFunctor {
public:
explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
explicit PReluFunctor(const T* alpha) : alpha_(alpha) {}
HOSTDEVICE T operator()(const T& x) const {
if (x > 0)
return x;
else
return x * alpha_;
return x * (*alpha_);
}
private:
T alpha_;
const T* alpha_;
};
template <typename Place, typename T>
......@@ -50,30 +50,29 @@ class PReluKernel : public framework::OpKernel {
const T* x_ptr = x->data<T>();
T* o_ptr = out->mutable_data<T>(context.GetPlace());
auto alpha_val = alpha->data<T>()[0];
// auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
auto* alpha_ptr = alpha->data<T>();
int numel = x->numel();
auto place = context.GetPlace();
Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr));
}
};
template <typename T>
class PReluGradFunctor {
public:
explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {}
HOSTDEVICE T operator()(const T& out, const T& dout) const {
if (out > 0)
return dout;
else
return dout * alpha_;
return dout * (*alpha_);
}
private:
T alpha_;
const T* alpha_;
};
template <typename Place, typename T>
......@@ -85,7 +84,7 @@ class PReluGradKernel : public framework::OpKernel {
auto* out = context.Input<Tensor>("Out");
auto* alpha = context.Input<Tensor>("Alpha");
auto alpha_val = alpha->data<T>()[0];
auto* alpha_ptr = alpha->data<T>();
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
const T* dout_ptr = dout->data<T>();
......@@ -94,7 +93,9 @@ class PReluGradKernel : public framework::OpKernel {
auto place = context.GetPlace();
Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
PReluGradFunctor<T>(alpha_val));
PReluGradFunctor<T>(alpha_ptr));
// TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
}
};
......
......@@ -3,13 +3,15 @@ import numpy as np
from op_test import OpTest
class PreluTest(OpTest):
class PReluTest(OpTest):
def setUp(self):
self.op_type = "prelu"
self.inputs = {'X': np.random.normal(size=(10, 10)).astype("float32")}
self.attrs = {'alpha': 0.1}
x_np = np.random.normal(size=(10, 10)).astype("float32")
alpha_np = np.array([.1])
self.inputs = {'X': x_np, 'Alpha': alpha_np}
out_np = np.maximum(self.inputs['X'], 0.)
out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.attrs['alpha']
out_np = out_np + np.minimum(self.inputs['X'],
0.) * self.inputs['Alpha']
assert out_np is not self.inputs['X']
self.outputs = {'Out': out_np}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册