提交 12440509 编写于 作者: W wanghaoshuang

Fix some inssues

上级 c7b6d2c4
...@@ -47,10 +47,7 @@ class ClipGradFunctor { ...@@ -47,10 +47,7 @@ class ClipGradFunctor {
public: public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const { HOSTDEVICE T operator()(const T& x, const T& y) const {
if (y > min_ && y < max_) return (y > min_ && y < max_) ? x : 0;
return x;
else
return 0;
} }
private: private:
...@@ -68,7 +65,7 @@ class ClipKernel : public framework::OpKernel { ...@@ -68,7 +65,7 @@ class ClipKernel : public framework::OpKernel {
auto* out = context.Output<Tensor>("Out"); auto* out = context.Output<Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
int numel = x->numel(); int64_t numel = x->numel();
Transform<Place> trans; Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data, trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max)); ClipFunctor<T>(min, max));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册