diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h index 63031c25cc3570cf40440726ea76976953d5417a..3269116c112f115e1e8fbbee0dc3b81dbe736e69 100644 --- a/paddle/operators/prelu_op.h +++ b/paddle/operators/prelu_op.h @@ -54,8 +54,9 @@ class PReluKernel : public framework::OpKernel { int numel = x->numel(); - Transform(context.device_context(), x_ptr, x_ptr + numel, o_ptr, - PReluFunctor(alpha_ptr)); + Transform trans; + trans(context.device_context(), x_ptr, x_ptr + numel, o_ptr, + PReluFunctor(alpha_ptr)); } }; @@ -91,8 +92,9 @@ class PReluGradKernel : public framework::OpKernel { const T* out_ptr = out->data(); int numel = dx->numel(); - Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, - dx_ptr, PReluGradFunctor(alpha_ptr)); + Transform trans; + trans(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, dx_ptr, + PReluGradFunctor(alpha_ptr)); // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready } diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h index 8eaab047fd4daa386f5ebdbb99a4caeed5fe2fbf..f196868c725cbb91b3df710260c5b60f14d53f37 100644 --- a/paddle/platform/transform.h +++ b/paddle/platform/transform.h @@ -29,45 +29,71 @@ namespace paddle { namespace platform { + // Transform on host or device. It provides the same API in std library. -template -void Transform(const DeviceContext& context, InputIter first, InputIter last, - OutputIter result, UnaryOperation op) { - auto place = context.GetPlace(); - if (is_cpu_place(place)) { +template +struct Transform { + template + void operator()(const DeviceContext& context, InputIter first, InputIter last, + OutputIter result, UnaryOperation op); + + template + void operator()(const DeviceContext& context, InputIter1 first1, + InputIter1 last1, InputIter2 first2, OutputIter result, + BinaryOperation op); +}; + +template <> +struct Transform { + template + void operator()(const DeviceContext& context, InputIter first, InputIter last, + OutputIter result, UnaryOperation op) { + auto place = context.GetPlace(); + PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); std::transform(first, last, result, op); - } else { -#ifdef __NVCC__ - auto& ctx = reinterpret_cast(context); - using namespace details; - thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first), - DevPtrCast(last), DevPtrCast(result), op); -#else - PADDLE_THROW("Do not invoke `Transform` in .cc file"); -#endif } -} -template -void Transform(const DeviceContext& context, InputIter1 first1, - InputIter1 last1, InputIter2 first2, OutputIter result, - BinaryOperation op) { - auto place = context.GetPlace(); - if (is_cpu_place(place)) { + template + void operator()(const DeviceContext& context, InputIter1 first1, + InputIter1 last1, InputIter2 first2, OutputIter result, + BinaryOperation op) { + auto place = context.GetPlace(); + PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); std::transform(first1, last1, first2, result, op); - } else { + } +}; + #ifdef __NVCC__ +template <> +struct Transform { + template + void operator()(const DeviceContext& context, InputIter first, InputIter last, + OutputIter result, UnaryOperation op) { + auto place = context.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place."); auto& ctx = reinterpret_cast(context); - using namespace details; - thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first1), - DevPtrCast(last1), DevPtrCast(first2), DevPtrCast(result), + thrust::transform(thrust::cuda::par.on(ctx.stream()), + details::DevPtrCast(first), details::DevPtrCast(last), + details::DevPtrCast(result), op); + } + + template + void operator()(const DeviceContext& context, InputIter1 first1, + InputIter1 last1, InputIter2 first2, OutputIter result, + BinaryOperation op) { + auto place = context.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place."); + auto& ctx = reinterpret_cast(context); + thrust::transform(thrust::cuda::par.on(ctx.stream()), + details::DevPtrCast(first1), details::DevPtrCast(last1), + details::DevPtrCast(first2), details::DevPtrCast(result), op); -#else - PADDLE_THROW("Do not invoke `Transform` in .cc file"); -#endif } }; +#endif } // namespace platform } // namespace paddle diff --git a/paddle/platform/transform_test.cu b/paddle/platform/transform_test.cu index b8a6200bb03c9a40b67be8d113012856e2a407e9..c76cab80e4b0e8df98a7be15f86699cfb6f93af2 100644 --- a/paddle/platform/transform_test.cu +++ b/paddle/platform/transform_test.cu @@ -15,6 +15,7 @@ #include #include "paddle/memory/memcpy.h" #include "paddle/memory/memory.h" +#include "paddle/platform/hostdevice.h" #include "paddle/platform/transform.h" template @@ -38,7 +39,8 @@ TEST(Transform, CPUUnary) { using namespace paddle::platform; CPUDeviceContext ctx; float buf[4] = {0.1, 0.2, 0.3, 0.4}; - Transform(ctx, buf, buf + 4, buf, Scale(10)); + Transform trans; + trans(ctx, buf, buf + 4, buf, Scale(10)); for (int i = 0; i < 4; ++i) { ASSERT_NEAR(buf[i], static_cast(i + 1), 1e-5); } @@ -52,7 +54,8 @@ TEST(Transform, GPUUnary) { float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4}; float* gpu_buf = static_cast(Alloc(gpu0, sizeof(float) * 4)); Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf)); - Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + Transform trans; + trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); ctx.Wait(); Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf)); Free(gpu0, gpu_buf); @@ -65,7 +68,9 @@ TEST(Transform, CPUBinary) { using namespace paddle::platform; using namespace paddle::memory; int buf[4] = {1, 2, 3, 4}; - Transform(CPUDeviceContext(), buf, buf + 4, buf, buf, Multiply()); + Transform trans; + CPUDeviceContext ctx; + trans(ctx, buf, buf + 4, buf, buf, Multiply()); for (int i = 0; i < 4; ++i) { ASSERT_EQ((i + 1) * (i + 1), buf[i]); } @@ -79,11 +84,12 @@ TEST(Transform, GPUBinary) { CUDADeviceContext ctx(gpu0); int* gpu_buf = static_cast(Alloc(gpu0, sizeof(buf))); Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf)); - Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); + Transform trans; + trans(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); ctx.Wait(); Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf)); Free(gpu0, gpu_buf); for (int i = 0; i < 4; ++i) { ASSERT_EQ((i + 1) * (i + 1), buf[i]); } -} \ No newline at end of file +} diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py index 76d1f1d5a418b7a2a91b36360a79317d063a72e7..2b6b7db36808a4b68c55328a1eb9ac212c18b678 100644 --- a/python/paddle/v2/framework/tests/test_prelu_op.py +++ b/python/paddle/v2/framework/tests/test_prelu_op.py @@ -17,10 +17,10 @@ class PReluTest(OpTest): assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} - def not_test_check_output(self): + def test_check_output(self): self.check_output() - def not_test_check_grad(self): + def test_check_grad(self): self.check_grad(['X'], 'Out')