From 51eb29de18adcf8c20272218f105eb1c2135cc09 Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Mon, 29 Mar 2021 14:17:54 +0800 Subject: [PATCH] [CustomOP] Add shape related constructor for Tensor (#31681) * give shape related contructor and reshape warning * change line num to fit ut * change ut to fit * remove useless code * call resize directly in constructor --- paddle/fluid/extension/include/ext_tensor.h | 3 +++ paddle/fluid/extension/src/ext_tensor.cc | 21 ++++++++++++++++++- paddle/fluid/framework/custom_tensor_utils.h | 2 +- .../fluid/tests/custom_op/custom_relu_op.cc | 3 +-- .../custom_op/test_custom_relu_op_jit.py | 4 ++-- 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/extension/include/ext_tensor.h b/paddle/fluid/extension/include/ext_tensor.h index be492a6d553..52606b2a7f5 100644 --- a/paddle/fluid/extension/include/ext_tensor.h +++ b/paddle/fluid/extension/include/ext_tensor.h @@ -52,6 +52,9 @@ class PD_DLL_DECL Tensor { /// \brief Construct a Tensor on target Place for CustomOp. /// Generally it's only used for user to create Tensor. explicit Tensor(const PlaceType& place); + /// \brief Construct a Tensor on target Place with shape for CustomOp. + /// Generally it's only used for user to create Tensor. + Tensor(const PlaceType& place, const std::vector& shape); /// \brief Reset the shape of the tensor. /// Generally it's only used for the input tensor. /// Reshape must be called before calling diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc index 0cae8f4af7b..e9705e2101c 100644 --- a/paddle/fluid/extension/src/ext_tensor.cc +++ b/paddle/fluid/extension/src/ext_tensor.cc @@ -102,13 +102,32 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, void Tensor::reshape(const std::vector &shape) { GET_CASTED_TENSOR - tensor->Resize(framework::make_ddim(shape)); + auto new_dim = framework::make_ddim(shape); + if (tensor->numel() != framework::product(new_dim)) { + LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger " + "or smaller" + << "than original shape will not change your tensor's memory " + "Please call" + << "paddle::Tensor::mutable_data() after to reallocate " + "your tensor's size." + << std::endl; + } + tensor->Resize(new_dim); } Tensor::Tensor(const PlaceType &place) : tensor_(std::make_shared()), place_(place), stream_(StreamWrapper()) {} + +Tensor::Tensor(const PlaceType &place, const std::vector &shape) + : tensor_(std::make_shared()), + place_(place), + stream_(StreamWrapper()) { + GET_CASTED_TENSOR + tensor->Resize(framework::make_ddim(shape)); +} + template T *Tensor::mutable_data(const PlaceType &place) { place_ = place; diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h index fad1e3ee349..809a6b965aa 100644 --- a/paddle/fluid/framework/custom_tensor_utils.h +++ b/paddle/fluid/framework/custom_tensor_utils.h @@ -37,7 +37,7 @@ class CustomTensorUtils { /// \brief Share data FROM another tensor. /// Use this to pass tensor from op to op /// \return void. - static void ShareDataFrom(const void* src, const Tensor& dst); + static void ShareDataFrom(const void* src, const paddle::Tensor& dst); static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType( const paddle::DataType& dtype) { diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc index c0b30a1cb55..c075d27f7b1 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc @@ -38,9 +38,8 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, } std::vector relu_cpu_forward(const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kCPU); + auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); - out.reshape(x.shape()); PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward", ([&] { relu_cpu_forward_kernel( diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py index 23733d20841..641630b0f44 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py @@ -103,11 +103,11 @@ class TestJITLoad(unittest.TestCase): in str(e)) if IS_WINDOWS: self.assertTrue( - r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:48" + r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:47" in str(e)) else: self.assertTrue( - "python/paddle/fluid/tests/custom_op/custom_relu_op.cc:48" + "python/paddle/fluid/tests/custom_op/custom_relu_op.cc:47" in str(e)) self.assertTrue(caught_exception) -- GitLab