diff --git a/paddle/fluid/extension/include/ext_tensor.h b/paddle/fluid/extension/include/ext_tensor.h
index be492a6d5535d17df579ec4fec8dd76d266a3029..52606b2a7f59e0b3340c2df4b641211511529240 100644
--- a/paddle/fluid/extension/include/ext_tensor.h
+++ b/paddle/fluid/extension/include/ext_tensor.h
@@ -52,6 +52,9 @@ class PD_DLL_DECL Tensor {
   /// \brief Construct a Tensor on target Place for CustomOp.
   /// Generally it's only used for user to create Tensor.
   explicit Tensor(const PlaceType& place);
+  /// \brief Construct a Tensor on target Place with shape for CustomOp.
+  /// Generally it's only used for user to create Tensor.
+  Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
   /// \brief Reset the shape of the tensor.
   /// Generally it's only used for the input tensor.
   /// Reshape must be called before calling
diff --git a/paddle/fluid/extension/src/ext_tensor.cc b/paddle/fluid/extension/src/ext_tensor.cc
index 0cae8f4af7b97de19d4daaad5422fd866ff0124a..e9705e2101cc3cedd99eddbe6e31caf2dcca68bd 100644
--- a/paddle/fluid/extension/src/ext_tensor.cc
+++ b/paddle/fluid/extension/src/ext_tensor.cc
@@ -102,13 +102,32 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
 
 void Tensor::reshape(const std::vector<int64_t> &shape) {
   GET_CASTED_TENSOR
-  tensor->Resize(framework::make_ddim(shape));
+  auto new_dim = framework::make_ddim(shape);
+  if (tensor->numel() != framework::product(new_dim)) {
+    LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger "
+                    "or smaller "
+                 << "than the original shape will not change your tensor's "
+                    "memory. Please call "
+                 << "paddle::Tensor::mutable_data() afterwards to reallocate "
+                    "your tensor's size."
+                 << std::endl;
+  }
+  tensor->Resize(new_dim);
 }
 
 Tensor::Tensor(const PlaceType &place)
     : tensor_(std::make_shared<framework::LoDTensor>()),
       place_(place),
       stream_(StreamWrapper()) {}
+
+Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
+    : tensor_(std::make_shared<framework::LoDTensor>()),
+      place_(place),
+      stream_(StreamWrapper()) {
+  GET_CASTED_TENSOR
+  tensor->Resize(framework::make_ddim(shape));
+}
+
 template <typename T>
 T *Tensor::mutable_data(const PlaceType &place) {
   place_ = place;
diff --git a/paddle/fluid/framework/custom_tensor_utils.h b/paddle/fluid/framework/custom_tensor_utils.h
index fad1e3ee3496cd410e4fca77c09f61c6ff53a402..809a6b965aad9bcb4594ecff99e460db723dfd53 100644
--- a/paddle/fluid/framework/custom_tensor_utils.h
+++ b/paddle/fluid/framework/custom_tensor_utils.h
@@ -37,7 +37,7 @@ class CustomTensorUtils {
   /// \brief Share data FROM another tensor.
   /// Use this to pass tensor from op to op
   /// \return void.
-  static void ShareDataFrom(const void* src, const Tensor& dst);
+  static void ShareDataFrom(const void* src, const paddle::Tensor& dst);
 
   static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
       const paddle::DataType& dtype) {
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
index c0b30a1cb5579c166845b8184ccb65bf515518ba..c075d27f7b1763babf54cd9d378b28f29fd84566 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -38,9 +38,8 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
 }
 
 std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
-  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
 
-  out.reshape(x.shape());
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
         relu_cpu_forward_kernel<data_t>(
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
index 23733d20841b3afa2347fb32fcc4335491b7f8cc..641630b0f4476a04752c20170a863ad8edd25c9a 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
@@ -103,11 +103,11 @@ class TestJITLoad(unittest.TestCase):
                 in str(e))
             if IS_WINDOWS:
                 self.assertTrue(
-                    r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:48"
+                    r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:47"
                     in str(e))
             else:
                 self.assertTrue(
-                    "python/paddle/fluid/tests/custom_op/custom_relu_op.cc:48"
+                    "python/paddle/fluid/tests/custom_op/custom_relu_op.cc:47"
                     in str(e))
         self.assertTrue(caught_exception)
 
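For reference, the sketch below (not part of the diff) shows how a custom operator kernel can allocate its output with the new two-argument constructor instead of the old construct-then-reshape pattern. The operator name zeros_like_cpu_forward and its float-only body are made-up examples; only paddle::Tensor(place, shape), shape(), size() and mutable_data<T>() come from the extension API touched above.

#include <vector>
#include "paddle/extension.h"

std::vector<paddle::Tensor> zeros_like_cpu_forward(const paddle::Tensor& x) {
  // Old pattern: Tensor(place) followed by out.reshape(x.shape()).
  // New pattern: pass place and shape together at construction time.
  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

  // As the new reshape() warning notes, setting a shape does not allocate
  // memory; mutable_data<T>() performs the actual allocation.
  auto* out_data = out.mutable_data<float>(paddle::PlaceType::kCPU);
  for (int64_t i = 0; i < out.size(); ++i) {
    out_data[i] = 0.0f;
  }
  return {out};
}

This mirrors the relu_cpu_forward change above, where the separate out.reshape(x.shape()) call is folded into the constructor.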