提交 1fdad1a6 编写于 作者: W wanghaoshuang

Update transform invocation

上级 3f3848cd
...@@ -80,5 +80,7 @@ class ClipOpGrad : public framework::OperatorWithKernel { ...@@ -80,5 +80,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad, REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad); ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip, ops::ClipKernel<float>); REGISTER_OP_CPU_KERNEL(clip,
REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>); ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::CPUPlace, float>);
...@@ -15,5 +15,7 @@ ...@@ -15,5 +15,7 @@
#include "paddle/operators/clip_op.h" #include "paddle/operators/clip_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip, ops::ClipKernel<float>); REGISTER_OP_GPU_KERNEL(clip,
REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradKernel<float>); ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::GPUPlace, float>);
...@@ -58,7 +58,7 @@ class ClipGradFunctor { ...@@ -58,7 +58,7 @@ class ClipGradFunctor {
T max_; T max_;
}; };
template <typename T> template <typename Place, typename T>
class ClipKernel : public framework::OpKernel { class ClipKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -69,12 +69,13 @@ class ClipKernel : public framework::OpKernel { ...@@ -69,12 +69,13 @@ class ClipKernel : public framework::OpKernel {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
int numel = x->numel(); int numel = x->numel();
Transform(context.device_context(), x_data, x_data + numel, out_data, Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max)); ClipFunctor<T>(min, max));
} }
}; };
template <typename T> template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel { class ClipGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -88,8 +89,9 @@ class ClipGradKernel : public framework::OpKernel { ...@@ -88,8 +89,9 @@ class ClipGradKernel : public framework::OpKernel {
auto d_x_data = d_x->mutable_data<T>(context.GetPlace()); auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>(); const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
Transform(context.device_context(), d_out_data, d_out_data + numel, Transform<Place> trans;
x_data, d_x_data, ClipGradFunctor<T>(min, max)); trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
d_x_data, ClipGradFunctor<T>(min, max));
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册