提交 1fdad1a6 编写于 作者: W wanghaoshuang

Update transform invocation

上级 3f3848cd
......@@ -80,5 +80,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip, ops::ClipKernel<float>);
REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
REGISTER_OP_CPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::CPUPlace, float>);
......@@ -15,5 +15,7 @@
#include "paddle/operators/clip_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip, ops::ClipKernel<float>);
REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
REGISTER_OP_GPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::GPUPlace, float>);
......@@ -58,7 +58,7 @@ class ClipGradFunctor {
T max_;
};
template <typename T>
template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -69,12 +69,13 @@ class ClipKernel : public framework::OpKernel {
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int numel = x->numel();
Transform(context.device_context(), x_data, x_data + numel, out_data,
Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
}
};
template <typename T>
template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -88,8 +89,9 @@ class ClipGradKernel : public framework::OpKernel {
auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform(context.device_context(), d_out_data, d_out_data + numel,
x_data, d_x_data, ClipGradFunctor<T>(min, max));
Transform<Place> trans;
trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
d_x_data, ClipGradFunctor<T>(min, max));
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册