未验证 提交 a891032f 编写于 作者: C Chen Weihang 提交者: GitHub

[Cherry-pick] Fix dtype unmatched in custom op API #31306

[Cherry-pick] Fix dtype unmatched in custom op API

cherry-pick of #31305
上级 628f0856
...@@ -57,7 +57,7 @@ class PD_DLL_DECL Tensor { ...@@ -57,7 +57,7 @@ class PD_DLL_DECL Tensor {
/// Reshape must be called before calling /// Reshape must be called before calling
/// mutable_data() or copy_to(const PlaceType& place) /// mutable_data() or copy_to(const PlaceType& place)
/// \param shape The shape to set. /// \param shape The shape to set.
void reshape(const std::vector<int>& shape); void reshape(const std::vector<int64_t>& shape);
/// \brief Get the memory pointer in CPU or GPU with /// \brief Get the memory pointer in CPU or GPU with
/// specific data type. /// specific data type.
...@@ -90,7 +90,7 @@ class PD_DLL_DECL Tensor { ...@@ -90,7 +90,7 @@ class PD_DLL_DECL Tensor {
Tensor copy_to(const PlaceType& place) const; Tensor copy_to(const PlaceType& place) const;
/// \brief Return the shape of the Tensor. /// \brief Return the shape of the Tensor.
std::vector<int> shape() const; std::vector<int64_t> shape() const;
/// \brief Return the data type of the tensor. /// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type. /// It's usually used to get the output tensor data type.
......
...@@ -95,7 +95,7 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc, ...@@ -95,7 +95,7 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
} \ } \
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get()); auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
void Tensor::reshape(const std::vector<int> &shape) { void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
} }
...@@ -251,9 +251,9 @@ template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>( ...@@ -251,9 +251,9 @@ template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place); const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place); template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
std::vector<int> Tensor::shape() const { std::vector<int64_t> Tensor::shape() const {
GET_CASTED_TENSOR GET_CASTED_TENSOR
return framework::vectorize<int>(tensor->dims()); return framework::vectorize<int64_t>(tensor->dims());
} }
const PlaceType &Tensor::place() const { const PlaceType &Tensor::place() const {
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
template <typename T> template <typename T>
paddle::Tensor InitCPUTensorForTest() { paddle::Tensor InitCPUTensorForTest() {
std::vector<int> tensor_shape{5, 5}; std::vector<int64_t> tensor_shape{5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape); t1.reshape(tensor_shape);
auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU); auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU);
...@@ -54,7 +54,7 @@ void TestCopyTensor() { ...@@ -54,7 +54,7 @@ void TestCopyTensor() {
} }
void TestAPIPlace() { void TestAPIPlace() {
std::vector<int> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto t1 = paddle::Tensor(paddle::PlaceType::kGPU); auto t1 = paddle::Tensor(paddle::PlaceType::kGPU);
t1.reshape(tensor_shape); t1.reshape(tensor_shape);
...@@ -68,7 +68,7 @@ void TestAPIPlace() { ...@@ -68,7 +68,7 @@ void TestAPIPlace() {
} }
void TestAPISizeAndShape() { void TestAPISizeAndShape() {
std::vector<int> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape); t1.reshape(tensor_shape);
CHECK_EQ(t1.size(), 25); CHECK_EQ(t1.size(), 25);
...@@ -77,7 +77,7 @@ void TestAPISizeAndShape() { ...@@ -77,7 +77,7 @@ void TestAPISizeAndShape() {
template <typename T> template <typename T>
paddle::DataType TestDtype() { paddle::DataType TestDtype() {
std::vector<int> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape); t1.reshape(tensor_shape);
t1.template mutable_data<T>(); t1.template mutable_data<T>();
...@@ -86,7 +86,7 @@ paddle::DataType TestDtype() { ...@@ -86,7 +86,7 @@ paddle::DataType TestDtype() {
template <typename T> template <typename T>
void TestCast(paddle::DataType data_type) { void TestCast(paddle::DataType data_type) {
std::vector<int> tensor_shape = {5, 5}; std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU); auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape); t1.reshape(tensor_shape);
t1.template mutable_data<T>(); t1.template mutable_data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册