From bc1c3e3ef698126f8724933f92651fc0a4ed29f3 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 18 Apr 2022 20:24:13 +0800 Subject: [PATCH] Create Tensor by paddle::empty in custom operator (#41840) * create tensor by empty in custom op * fix some bug --- .../final_state_generator/codegen_utils.py | 2 +- .../final_state_generator/python_c_gen.py | 2 +- paddle/fluid/pybind/eager_utils.cc | 10 ++++------ paddle/fluid/pybind/eager_utils.h | 10 ++++------ paddle/phi/common/place.h | 3 --- paddle/phi/tests/api/test_data_transform.cc | 19 ++++++++----------- paddle/phi/tests/api/test_scale_benchmark.cc | 2 +- .../tests/custom_op/context_pool_test_op.cc | 6 ++---- .../fluid/tests/custom_op/custom_concat_op.cc | 8 ++++---- .../fluid/tests/custom_op/custom_conj_op.cc | 2 +- .../fluid/tests/custom_op/custom_relu_op.cc | 8 ++++---- .../fluid/tests/custom_op/custom_relu_op.cu | 8 ++++---- .../fluid/tests/custom_op/custom_tanh_op.cc | 8 ++++---- 13 files changed, 38 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index ab8c28c33e7..7769c5371ba 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -45,7 +45,7 @@ yaml_types_mapping = { 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'str' : 'std::string', \ - 'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ + 'Place' : 'paddle::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'int64_t[]' : 'std::vector', 'int[]' : 'std::vector', 'Tensor' : 'Tensor', 'Tensor[]' : 'std::vector', diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 8075b65b194..e2bb4104551 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -46,7 +46,7 @@ atype_to_parsing_function = { "std::vector": "CastPyArg2Strings", "paddle::experimental::Scalar": "CastPyArg2Scalar", "paddle::experimental::IntArray": "CastPyArg2IntArray", - "paddle::experimental::Place": "CastPyArg2Place", + "paddle::Place": "CastPyArg2Place", "paddle::experimental::DataType": "CastPyArg2DataType", } diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 8fa21ef45f8..081e2783826 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1151,15 +1151,13 @@ std::vector GetScopePtrListFromArgs( return result; } -paddle::experimental::Place CastPyArg2Place(PyObject* obj, - const std::string& op_type, - ssize_t arg_pos) { +paddle::Place CastPyArg2Place(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { return CastPyArg2Place(obj, arg_pos); } -paddle::experimental::DataType CastPyArg2DataType(PyObject* obj, - const std::string& op_type, - ssize_t arg_pos) { +paddle::DataType CastPyArg2DataType(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (obj == Py_None) { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 90c4d727923..e0ad69871a1 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -162,13 +162,11 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj, const std::string& op_type, ssize_t arg_pos); -paddle::experimental::Place CastPyArg2Place(PyObject* obj, - const std::string& op_type, - ssize_t arg_pos); +paddle::Place CastPyArg2Place(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); -paddle::experimental::DataType CastPyArg2DataType(PyObject* obj, - const std::string& op_type, - ssize_t arg_pos); +paddle::DataType CastPyArg2DataType(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); paddle::optional GetOptionalTensorFromArgs( const std::string& op_type, const std::string& arg_name, PyObject* args, diff --git a/paddle/phi/common/place.h b/paddle/phi/common/place.h index ed9fb787642..199ee81f272 100644 --- a/paddle/phi/common/place.h +++ b/paddle/phi/common/place.h @@ -213,9 +213,6 @@ std::ostream& operator<<(std::ostream&, const Place&); namespace paddle { namespace experimental { using AllocationType = phi::AllocationType; -using Place = phi::Place; -using CPUPlace = phi::CPUPlace; -using GPUPlace = phi::GPUPlace; using GPUPinnedPlace = phi::GPUPinnedPlace; using XPUPlace = phi::XPUPlace; using NPUPlace = phi::NPUPlace; diff --git a/paddle/phi/tests/api/test_data_transform.cc b/paddle/phi/tests/api/test_data_transform.cc index a2bd1f2cad9..21d5eef4098 100644 --- a/paddle/phi/tests/api/test_data_transform.cc +++ b/paddle/phi/tests/api/test_data_transform.cc @@ -37,13 +37,11 @@ namespace tests { // TODO(chenweihang): Remove this test after the API is used in the dygraph TEST(API, data_transform_same_place) { // 1. create tensor - auto x = paddle::experimental::full({3, 3}, - 1.0, - experimental::DataType::COMPLEX128, - experimental::CPUPlace()); + auto x = + paddle::experimental::full({3, 3}, 1.0, DataType::COMPLEX128, CPUPlace()); - auto y = paddle::experimental::full( - {3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::CPUPlace()); + auto y = + paddle::experimental::full({3, 3}, 2.0, DataType::FLOAT32, CPUPlace()); std::vector> sum(9, 6.0); @@ -75,10 +73,10 @@ TEST(API, data_transform_same_place) { TEST(Tensor, data_transform_diff_place) { // 1. create tensor auto x = paddle::experimental::full( - {3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::CPUPlace()); + {3, 3}, 1.0, experimental::DataType::FLOAT64, CPUPlace()); auto y = paddle::experimental::full( - {3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::GPUPlace()); + {3, 3}, 2.0, experimental::DataType::FLOAT64, GPUPlace()); std::vector sum(9, 6.0); @@ -93,10 +91,9 @@ TEST(Tensor, data_transform_diff_place) { ASSERT_EQ(out.dtype(), phi::DataType::FLOAT64); ASSERT_EQ(out.layout(), phi::DataLayout::NCHW); ASSERT_EQ(out.initialized(), true); - ASSERT_EQ(out.impl()->place(), - phi::TransToPhiPlace(experimental::Backend::GPU)); + ASSERT_EQ(out.impl()->place(), phi::TransToPhiPlace(phi::Backend::GPU)); - auto ref_out = experimental::copy_to(out, experimental::CPUPlace(), true); + auto ref_out = experimental::copy_to(out, CPUPlace(), true); auto dense_out = std::dynamic_pointer_cast(ref_out.impl()); for (size_t i = 0; i < 9; i++) { diff --git a/paddle/phi/tests/api/test_scale_benchmark.cc b/paddle/phi/tests/api/test_scale_benchmark.cc index ca4a264e511..e2870a780ae 100644 --- a/paddle/phi/tests/api/test_scale_benchmark.cc +++ b/paddle/phi/tests/api/test_scale_benchmark.cc @@ -30,7 +30,7 @@ namespace tests { TEST(API, scale) { auto x = experimental::full( - {3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::CPUPlace()); + {3, 4}, 1.0, experimental::DataType::FLOAT32, CPUPlace()); const size_t cycles = 300; phi::tests::Timer timer; diff --git a/python/paddle/fluid/tests/custom_op/context_pool_test_op.cc b/python/paddle/fluid/tests/custom_op/context_pool_test_op.cc index 6b0edcc7ab1..9286ae7ca00 100644 --- a/python/paddle/fluid/tests/custom_op/context_pool_test_op.cc +++ b/python/paddle/fluid/tests/custom_op/context_pool_test_op.cc @@ -22,8 +22,7 @@ std::vector ContextPoolTest(const paddle::Tensor& x) { // 1. test cpu context - paddle::experimental::Place cpu_place( - paddle::experimental::AllocationType::CPU); + paddle::Place cpu_place(paddle::experimental::AllocationType::CPU); auto* cpu_ctx = paddle::experimental::DeviceContextPool::Instance() .Get(cpu_place); @@ -34,8 +33,7 @@ std::vector ContextPoolTest(const paddle::Tensor& x) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // 2. test gpu context - paddle::experimental::Place gpu_place( - paddle::experimental::AllocationType::GPU); + paddle::Place gpu_place(paddle::experimental::AllocationType::GPU); auto* gpu_ctx = paddle::experimental::DeviceContextPool::Instance() .Get(gpu_place); diff --git a/python/paddle/fluid/tests/custom_op/custom_concat_op.cc b/python/paddle/fluid/tests/custom_op/custom_concat_op.cc index 66cc36c300e..80f76e2df54 100644 --- a/python/paddle/fluid/tests/custom_op/custom_concat_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_concat_op.cc @@ -75,7 +75,7 @@ std::vector ConcatForwardDynamicAxis( auto out_shape = ComputeOutShape(in_shapes, axis); // create output - auto out = paddle::Tensor(paddle::PlaceType::kCPU, out_shape); + auto out = paddle::empty(out_shape, inputs[0].type(), paddle::CPUPlace()); // calc PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( @@ -106,7 +106,7 @@ std::vector ConcatBackwardDynamicAxis( // create outputs std::vector grad_inputs; for (auto& t : inputs) { - auto grad = paddle::Tensor(paddle::PlaceType::kCPU, t.shape()); + auto grad = paddle::empty(t.shape(), t.dtype(), t.place()); grad_inputs.emplace_back(grad); } @@ -161,7 +161,7 @@ std::vector ConcatForwardStaticAxis( auto out_shape = ComputeOutShape(in_shapes, final_axis); // create output - auto out = paddle::Tensor(paddle::PlaceType::kCPU, out_shape); + auto out = paddle::empty(out_shape, inputs[0].type(), paddle::CPUPlace()); // calc PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES( @@ -190,7 +190,7 @@ std::vector ConcatBackwardStaticAxis( // create outputs std::vector grad_inputs; for (auto& t : inputs) { - auto grad = paddle::Tensor(paddle::PlaceType::kCPU, t.shape()); + auto grad = paddle::empty(t.shape(), t.dtype(), t.place()); grad_inputs.emplace_back(grad); } diff --git a/python/paddle/fluid/tests/custom_op/custom_conj_op.cc b/python/paddle/fluid/tests/custom_op/custom_conj_op.cc index b9c10f479e0..56938552420 100644 --- a/python/paddle/fluid/tests/custom_op/custom_conj_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_conj_op.cc @@ -71,7 +71,7 @@ void ConjCPUKernel(const data_t* x_data, int64_t numel, data_t* out_data) { std::vector ConjFunction(const paddle::Tensor& x) { CHECK_INPUT(x); - paddle::Tensor out(x.place(), x.shape()); + paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place()); PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES( x.type(), "ConjCPUKernel", ([&] { diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc index 121a855a18f..04399a9826c 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc @@ -54,7 +54,7 @@ void relu_cpu_double_backward_kernel(const data_t* out_data, } std::vector relu_cpu_forward(const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto out = paddle::empty(x.shape(), x.dtype(), x.place()); PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward", ([&] { @@ -68,7 +68,7 @@ std::vector relu_cpu_forward(const paddle::Tensor& x) { std::vector relu_cpu_backward(const paddle::Tensor& x, const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { relu_cpu_backward_kernel( @@ -85,7 +85,7 @@ std::vector relu_cpu_double_backward( const paddle::Tensor& out, const paddle::Tensor& ddx) { CHECK_CPU_INPUT(out); CHECK_CPU_INPUT(ddx); - auto ddout = paddle::Tensor(paddle::PlaceType::kCPU, out.shape()); + auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] { relu_cpu_double_backward_kernel( @@ -165,7 +165,7 @@ PD_BUILD_DOUBLE_GRAD_OP(custom_relu) std::vector relu_cpu_backward_without_x( const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape()); + auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { relu_cpu_backward_kernel( diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu index 364a2216b9e..18f1a2b95c2 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu @@ -54,7 +54,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data, std::vector relu_cuda_forward(const paddle::Tensor& x) { CHECK_GPU_INPUT(x); - auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto out = paddle::empty(x.shape(), x.dtype(), x.place()); int numel = x.size(); int block = 512; @@ -74,7 +74,7 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, CHECK_GPU_INPUT(x); CHECK_GPU_INPUT(out); CHECK_GPU_INPUT(grad_out); - auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place()); int numel = out.size(); int block = 512; @@ -95,7 +95,7 @@ std::vector relu_cuda_double_backward( const paddle::Tensor& out, const paddle::Tensor& ddx) { CHECK_GPU_INPUT(out); CHECK_GPU_INPUT(ddx); - auto ddout = paddle::Tensor(paddle::PlaceType::kGPU, out.shape()); + auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); int64_t numel = out.size(); int64_t block = 512; @@ -117,7 +117,7 @@ std::vector relu_cuda_double_backward( std::vector relu_cuda_backward_without_x( const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape()); + auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); int numel = out.size(); int block = 512; diff --git a/python/paddle/fluid/tests/custom_op/custom_tanh_op.cc b/python/paddle/fluid/tests/custom_op/custom_tanh_op.cc index f96297d69bd..399eb5b6366 100644 --- a/python/paddle/fluid/tests/custom_op/custom_tanh_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_tanh_op.cc @@ -68,7 +68,7 @@ void tanh_cpu_double_backward_kernel(const data_t* out_data, std::vector TanhForward(const paddle::Tensor& x) { CHECK_CPU_INPUT(x); - auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto out = paddle::empty(x.shape(), x.dtype(), x.place()); PD_DISPATCH_FLOATING_TYPES( x.dtype(), "tanh_cpu_forward", ([&] { @@ -82,7 +82,7 @@ std::vector TanhForward(const paddle::Tensor& x) { std::vector TanhBackward(const paddle::Tensor& out, const paddle::Tensor& grad_out) { CHECK_CPU_INPUT(out); - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape()); + auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place()); PD_DISPATCH_FLOATING_TYPES(out.dtype(), "tanh_cpu_backward", ([&] { tanh_cpu_backward_kernel( @@ -101,8 +101,8 @@ std::vector TanhDoubleBackward(const paddle::Tensor& out, CHECK_CPU_INPUT(out); CHECK_CPU_INPUT(ddx); CHECK_CPU_INPUT(dout); - auto dout_new = paddle::Tensor(paddle::PlaceType::kCPU, out.shape()); - auto ddout = paddle::Tensor(paddle::PlaceType::kCPU, out.shape()); + auto dout_new = paddle::empty(out.shape(), out.dtype(), out.place()); + auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); PD_DISPATCH_FLOATING_TYPES(out.dtype(), "tanh_cpu_double_backward", ([&] { tanh_cpu_double_backward_kernel( -- GitLab