未验证 提交 a891032f 编写于 作者: C Chen Weihang 提交者: GitHub

[Cherry-pick] Fix dtype unmatched in custom op API #31306

[Cherry-pick] Fix dtype unmatched in custom op API

cherry-pick of #31305
上级 628f0856
......@@ -57,7 +57,7 @@ class PD_DLL_DECL Tensor {
/// Reshape must be called before calling
/// mutable_data() or copy_to(const PlaceType& place)
/// \param shape The shape to set.
void reshape(const std::vector<int>& shape);
void reshape(const std::vector<int64_t>& shape);
/// \brief Get the memory pointer in CPU or GPU with
/// specific data type.
......@@ -90,7 +90,7 @@ class PD_DLL_DECL Tensor {
Tensor copy_to(const PlaceType& place) const;
/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
std::vector<int64_t> shape() const;
/// \brief Return the data type of the tensor.
/// It's usually used to get the output tensor data type.
......
......@@ -95,7 +95,7 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
} \
auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());
void Tensor::reshape(const std::vector<int> &shape) {
void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR
tensor->Resize(framework::make_ddim(shape));
}
......@@ -251,9 +251,9 @@ template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
std::vector<int> Tensor::shape() const {
std::vector<int64_t> Tensor::shape() const {
GET_CASTED_TENSOR
return framework::vectorize<int>(tensor->dims());
return framework::vectorize<int64_t>(tensor->dims());
}
const PlaceType &Tensor::place() const {
......
......@@ -20,7 +20,7 @@
template <typename T>
paddle::Tensor InitCPUTensorForTest() {
std::vector<int> tensor_shape{5, 5};
std::vector<int64_t> tensor_shape{5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU);
......@@ -54,7 +54,7 @@ void TestCopyTensor() {
}
void TestAPIPlace() {
std::vector<int> tensor_shape = {5, 5};
std::vector<int64_t> tensor_shape = {5, 5};
#ifdef PADDLE_WITH_CUDA
auto t1 = paddle::Tensor(paddle::PlaceType::kGPU);
t1.reshape(tensor_shape);
......@@ -68,7 +68,7 @@ void TestAPIPlace() {
}
void TestAPISizeAndShape() {
std::vector<int> tensor_shape = {5, 5};
std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
CHECK_EQ(t1.size(), 25);
......@@ -77,7 +77,7 @@ void TestAPISizeAndShape() {
template <typename T>
paddle::DataType TestDtype() {
std::vector<int> tensor_shape = {5, 5};
std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
......@@ -86,7 +86,7 @@ paddle::DataType TestDtype() {
template <typename T>
void TestCast(paddle::DataType data_type) {
std::vector<int> tensor_shape = {5, 5};
std::vector<int64_t> tensor_shape = {5, 5};
auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
t1.reshape(tensor_shape);
t1.template mutable_data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册