未验证 提交 51eb29de 编写于 作者: J Jiabin Yang 提交者: GitHub

[CustomOP] Add shape related constructor for Tensor (#31681)

* give shape related contructor and reshape warning

* change line num to fit ut

* change ut to fit

* remove useless code

* call resize directly in constructor
上级 e3a38d79
...@@ -52,6 +52,9 @@ class PD_DLL_DECL Tensor { ...@@ -52,6 +52,9 @@ class PD_DLL_DECL Tensor {
/// \brief Construct a Tensor on target Place for CustomOp. /// \brief Construct a Tensor on target Place for CustomOp.
/// Generally it's only used for user to create Tensor. /// Generally it's only used for user to create Tensor.
explicit Tensor(const PlaceType& place); explicit Tensor(const PlaceType& place);
/// \brief Construct a Tensor on target Place with shape for CustomOp.
/// Generally it's only used for user to create Tensor.
Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
/// \brief Reset the shape of the tensor. /// \brief Reset the shape of the tensor.
/// Generally it's only used for the input tensor. /// Generally it's only used for the input tensor.
/// Reshape must be called before calling /// Reshape must be called before calling
......
...@@ -102,13 +102,32 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, ...@@ -102,13 +102,32 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
void Tensor::reshape(const std::vector<int64_t> &shape) { void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape)); auto new_dim = framework::make_ddim(shape);
if (tensor->numel() != framework::product(new_dim)) {
LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger "
"or smaller"
<< "than original shape will not change your tensor's memory "
"Please call"
<< "paddle::Tensor::mutable_data<T>() after to reallocate "
"your tensor's size."
<< std::endl;
}
tensor->Resize(new_dim);
} }
Tensor::Tensor(const PlaceType &place) Tensor::Tensor(const PlaceType &place)
: tensor_(std::make_shared<framework::LoDTensor>()), : tensor_(std::make_shared<framework::LoDTensor>()),
place_(place), place_(place),
stream_(StreamWrapper()) {} stream_(StreamWrapper()) {}
Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
: tensor_(std::make_shared<framework::LoDTensor>()),
place_(place),
stream_(StreamWrapper()) {
GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape));
}
template <typename T> template <typename T>
T *Tensor::mutable_data(const PlaceType &place) { T *Tensor::mutable_data(const PlaceType &place) {
place_ = place; place_ = place;
......
...@@ -37,7 +37,7 @@ class CustomTensorUtils { ...@@ -37,7 +37,7 @@ class CustomTensorUtils {
/// \brief Share data FROM another tensor. /// \brief Share data FROM another tensor.
/// Use this to pass tensor from op to op /// Use this to pass tensor from op to op
/// \return void. /// \return void.
static void ShareDataFrom(const void* src, const Tensor& dst); static void ShareDataFrom(const void* src, const paddle::Tensor& dst);
static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType( static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
const paddle::DataType& dtype) { const paddle::DataType& dtype) {
......
...@@ -38,9 +38,8 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, ...@@ -38,9 +38,8 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
} }
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) { std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU); auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
out.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES( PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] { x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>( relu_cpu_forward_kernel<data_t>(
......
...@@ -103,11 +103,11 @@ class TestJITLoad(unittest.TestCase): ...@@ -103,11 +103,11 @@ class TestJITLoad(unittest.TestCase):
in str(e)) in str(e))
if IS_WINDOWS: if IS_WINDOWS:
self.assertTrue( self.assertTrue(
r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:48" r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:47"
in str(e)) in str(e))
else: else:
self.assertTrue( self.assertTrue(
"python/paddle/fluid/tests/custom_op/custom_relu_op.cc:48" "python/paddle/fluid/tests/custom_op/custom_relu_op.cc:47"
in str(e)) in str(e))
self.assertTrue(caught_exception) self.assertTrue(caught_exception)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册