提交 2cde56c5 编写于 作者: W wanghaoshuang

Use Transform instead of eigen

上级 743dfd82
......@@ -80,6 +80,5 @@ class ClipOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip, ops::ClipKernel<float>);
REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
......@@ -14,60 +14,6 @@
#include "paddle/operators/clip_op.h"
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
namespace paddle {
namespace operators {
using framework::LoDTensor;
template <typename T>
__global__ void ClipGradientKernel(const int N, const T min, const T max,
const T* Y, const T* dY, T* dX) {
CUDA_1D_KERNEL_LOOP(i, N) {
if (Y[i] > min && Y[i] < max) {
dX[i] = dY[i];
} else {
dX[i] = 0;
}
}
}
template <typename T>
class ClipGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<float>("max");
auto min = context.Attr<float>("min");
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<LoDTensor>("X");
auto dims = d_x->dims();
int64_t count = d_out->numel();
auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
auto d_out_data = d_out->data<T>();
auto x_data = x->data<T>();
int N = d_x->dims()[0];
int D = d_x->dims()[1];
int block = 512;
int grid = (N * D + block - 1) / block;
ClipGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(count, min, max, x_data, d_out_data,
d_x_data);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradientOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(clip, ops::ClipKernel<float>);
REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
......@@ -16,57 +16,61 @@
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
using platform::Transform;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T>
class ClipFunctor {
public:
explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x) const {
if (x < min_)
return min_;
else if (x > max_)
return max_;
else
return x;
}
private:
T min_;
T max_;
};
template <typename T>
class ClipGradFunctor {
public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const {
if (y > min_ && y < max_)
return x;
else
return 0;
}
template <typename Place, typename T, size_t D>
void ClipFunction(const framework::ExecutionContext& context) {
auto max = context.op().Attr<float>("max");
auto min = context.op().Attr<float>("min");
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto x_tensor = EigenTensor<T, D>::From(*x);
auto out_tensor = EigenTensor<T, D>::From(*out);
auto place = context.GetEigenDevice<Place>();
out_tensor.device(place) = x_tensor.cwiseMin(max).cwiseMax(min);
}
private:
T min_;
T max_;
};
template <typename Place, typename T>
template <typename T>
class ClipKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<LoDTensor>("X")->dims().size();
switch (rank) {
case 1:
ClipFunction<Place, T, 1>(context);
break;
case 2:
ClipFunction<Place, T, 2>(context);
break;
case 3:
ClipFunction<Place, T, 3>(context);
break;
case 4:
ClipFunction<Place, T, 4>(context);
break;
case 5:
ClipFunction<Place, T, 5>(context);
break;
case 6:
ClipFunction<Place, T, 6>(context);
break;
default:
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
}
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int numel = x->numel();
Transform(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
}
};
......@@ -74,24 +78,18 @@ template <typename T>
class ClipGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.op().Attr<float>("max");
auto min = context.op().Attr<float>("min");
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<LoDTensor>("X");
auto dims = d_x->dims();
int64_t count = d_out->numel();
auto* x = context.Input<Tensor>("X");
int64_t numel = d_out->numel();
auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
auto d_out_data = d_out->data<T>();
auto x_data = x->data<T>();
for (int i = 0; i < count; ++i) {
if (x_data[i] > min && x_data[i] < max) {
d_x_data[i] = d_out_data[i];
} else {
d_x_data[i] = 0;
}
}
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform(context.device_context(), d_out_data, d_out_data + numel,
x_data, d_x_data, ClipGradFunctor<T>(min, max));
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册