未验证 提交 5b7fadec 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Replace Backend by Place in C++ API (#40732)

* replace Backend by Place in C++ API

* fix left code

* fix test_to_api bug
上级 67b46e45
...@@ -395,7 +395,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( ...@@ -395,7 +395,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
platform::CUDADeviceGuard gpuGuard; platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) { for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId()); gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::Backend::GPU); auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::GPUPlace());
barrierTensors.push_back(dt); barrierTensors.push_back(dt);
} }
auto task = ProcessGroupNCCL::AllReduce(barrierTensors); auto task = ProcessGroupNCCL::AllReduce(barrierTensors);
......
...@@ -321,7 +321,7 @@ EagerReducer::EagerReducer( ...@@ -321,7 +321,7 @@ EagerReducer::EagerReducer(
if (find_unused_vars_each_step_) { if (find_unused_vars_each_step_) {
global_used_vars_ = paddle::experimental::empty( global_used_vars_ = paddle::experimental::empty(
ScalarArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32, ScalarArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32,
TransToBackend(inner_place_)); inner_place_);
} }
} }
...@@ -363,10 +363,8 @@ void EagerReducer::InitializeGroups( ...@@ -363,10 +363,8 @@ void EagerReducer::InitializeGroups(
} else { } else {
// process the dense gradient. // process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group); InitializeDenseGroups(tensor_indices_, &group);
// experimental::Backend backend = TransToBackend(inner_place_);
group.dense_contents_ = paddle::experimental::empty( group.dense_contents_ = paddle::experimental::empty(
ScalarArray({group.all_length_}), group.dtype_, ScalarArray({group.all_length_}), group.dtype_, inner_place_);
TransToBackend(inner_place_));
} }
// map tensors to this group by VariableLocator // map tensors to this group by VariableLocator
......
...@@ -43,8 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue( ...@@ -43,8 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
const phi::DataType& dtype, const phi::DataLayout& layout, float value, const phi::DataType& dtype, const phi::DataLayout& layout, float value,
bool is_leaf) { bool is_leaf) {
paddle::experimental::Tensor out = paddle::experimental::full( paddle::experimental::Tensor out = paddle::experimental::full(
phi::vectorize(ddim), paddle::experimental::Scalar(value), dtype, phi::vectorize(ddim), paddle::experimental::Scalar(value), dtype, place);
phi::TransToPhiBackend(place));
auto meta = EagerUtils::autograd_meta(&out); auto meta = EagerUtils::autograd_meta(&out);
if (is_leaf) { if (is_leaf) {
......
...@@ -29,7 +29,7 @@ yaml_types_mapping = { ...@@ -29,7 +29,7 @@ yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \ 'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \ 'str' : 'std::string', \
'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'Place' : 'paddle::experimental::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>', 'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor', 'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>', 'Tensor[]' : 'std::vector<Tensor>',
......
...@@ -43,7 +43,7 @@ atype_to_parsing_function = { ...@@ -43,7 +43,7 @@ atype_to_parsing_function = {
"std::vector<std::string>": "CastPyArg2Strings", "std::vector<std::string>": "CastPyArg2Strings",
"paddle::experimental::Scalar": "CastPyArg2Scalar", "paddle::experimental::Scalar": "CastPyArg2Scalar",
"paddle::experimental::ScalarArray": "CastPyArg2ScalarArray", "paddle::experimental::ScalarArray": "CastPyArg2ScalarArray",
"paddle::experimental::Backend": "CastPyArg2Backend", "paddle::experimental::Place": "CastPyArg2Place",
"paddle::experimental::DataType": "CastPyArg2DataType", "paddle::experimental::DataType": "CastPyArg2DataType",
} }
......
...@@ -132,8 +132,7 @@ void InitTensorWithTensor(TensorObject* self, ...@@ -132,8 +132,7 @@ void InitTensorWithTensor(TensorObject* self,
self->tensor.set_impl(impl); self->tensor.set_impl(impl);
VLOG(4) << "Same place, do ShareDataWith"; VLOG(4) << "Same place, do ShareDataWith";
} else { } else {
self->tensor.set_impl( self->tensor.set_impl(src.copy_to(place, true).impl());
src.copy_to(phi::TransToPhiBackend(place), true).impl());
VLOG(4) << "Different place, do TensorCopy"; VLOG(4) << "Different place, do TensorCopy";
} }
if (src.get_autograd_meta()) { if (src.get_autograd_meta()) {
...@@ -156,8 +155,7 @@ void InitTensorWithFrameworkTensor(TensorObject* self, ...@@ -156,8 +155,7 @@ void InitTensorWithFrameworkTensor(TensorObject* self,
} else { } else {
auto temp = auto temp =
paddle::experimental::Tensor(std::make_shared<phi::DenseTensor>(src)); paddle::experimental::Tensor(std::make_shared<phi::DenseTensor>(src));
self->tensor.set_impl( self->tensor.set_impl(temp.copy_to(place, true).impl());
temp.copy_to(phi::TransToPhiBackend(place), true).impl());
VLOG(4) << "Different place, do TensorCopy"; VLOG(4) << "Different place, do TensorCopy";
} }
egr::EagerUtils::autograd_meta(&(self->tensor))->SetPersistable(false); egr::EagerUtils::autograd_meta(&(self->tensor))->SetPersistable(false);
......
...@@ -159,7 +159,7 @@ static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args, ...@@ -159,7 +159,7 @@ static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args,
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2); auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3); bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3);
dst = src.copy_to(phi::TransToPhiBackend(place), blocking); dst = src.copy_to(place, blocking);
egr::EagerUtils::autograd_meta(&dst)->SetStopGradient( egr::EagerUtils::autograd_meta(&dst)->SetStopGradient(
egr::EagerUtils::autograd_meta(&(src))->StopGradient()); egr::EagerUtils::autograd_meta(&(src))->StopGradient());
egr::EagerUtils::autograd_meta(&dst)->SetPersistable( egr::EagerUtils::autograd_meta(&dst)->SetPersistable(
......
...@@ -218,8 +218,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args, ...@@ -218,8 +218,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
EAGER_TRY EAGER_TRY
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0); auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1); bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
auto cp_tensor = auto cp_tensor = self->tensor.copy_to(place, blocking);
self->tensor.copy_to(phi::TransToPhiBackend(place), blocking);
egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true); egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
egr::EagerUtils::autograd_meta(&cp_tensor) egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable( ->SetPersistable(
...@@ -231,8 +230,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args, ...@@ -231,8 +230,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
static PyObject* tensor_method_cpu(TensorObject* self, PyObject* args, static PyObject* tensor_method_cpu(TensorObject* self, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
auto cp_tensor = auto cp_tensor = self->tensor.copy_to(phi::CPUPlace(), true);
self->tensor.copy_to(phi::TransToPhiBackend(phi::CPUPlace()), true);
egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true); egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
egr::EagerUtils::autograd_meta(&cp_tensor) egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable( ->SetPersistable(
......
...@@ -929,28 +929,10 @@ std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs( ...@@ -929,28 +929,10 @@ std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs(
return result; return result;
} }
paddle::experimental::Backend CastPyArg2Backend(PyObject* obj, paddle::experimental::Place CastPyArg2Place(PyObject* obj,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
if (obj == Py_None) { return CastPyArg2Place(obj, arg_pos);
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int or place, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
if (type_name == "int") {
int value = CastPyArg2Int(obj, op_type, arg_pos);
return static_cast<paddle::experimental::Backend>(value);
} else {
platform::Place place = CastPyArg2Place(obj, arg_pos);
return phi::TransToPhiBackend(place);
}
return paddle::experimental::Backend::CPU;
} }
paddle::experimental::DataType CastPyArg2DataType(PyObject* obj, paddle::experimental::DataType CastPyArg2DataType(PyObject* obj,
......
...@@ -154,7 +154,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, ...@@ -154,7 +154,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
paddle::experimental::ScalarArray CastPyArg2ScalarArray( paddle::experimental::ScalarArray CastPyArg2ScalarArray(
PyObject* obj, const std::string& op_type, ssize_t arg_pos); PyObject* obj, const std::string& op_type, ssize_t arg_pos);
paddle::experimental::Backend CastPyArg2Backend(PyObject* obj, paddle::experimental::Place CastPyArg2Place(PyObject* obj,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos); ssize_t arg_pos);
......
...@@ -31,7 +31,6 @@ using gpuStream_t = hipStream_t; ...@@ -31,7 +31,6 @@ using gpuStream_t = hipStream_t;
#include "paddle/phi/api/ext/dll_decl.h" #include "paddle/phi/api/ext/dll_decl.h"
#include "paddle/phi/api/ext/place.h" #include "paddle/phi/api/ext/place.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
...@@ -415,11 +414,11 @@ class PADDLE_API Tensor final { ...@@ -415,11 +414,11 @@ class PADDLE_API Tensor final {
/** /**
* @brief Transfer the current Tensor to the specified device and return. * @brief Transfer the current Tensor to the specified device and return.
* *
* @param backend, The target backend of which the tensor will copy to. * @param place, The target place of which the tensor will copy to.
* @param blocking, Should we copy this in sync way. * @param blocking, Should we copy this in sync way.
* @return Tensor * @return Tensor
*/ */
Tensor copy_to(Backend backend, bool blocking) const; Tensor copy_to(Place place, bool blocking) const;
/** /**
* @brief Transfer the source Tensor to current Tensor. * @brief Transfer the source Tensor to current Tensor.
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/binary.h"
...@@ -31,9 +32,10 @@ limitations under the License. */ ...@@ -31,9 +32,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) { Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x); auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend); kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key); "copy", kernel_key);
...@@ -57,8 +59,7 @@ Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) { ...@@ -57,8 +59,7 @@ Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
phi::DenseTensor*); phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)( (*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
*dev_ctx, *dense_x, phi::TransToPhiPlace(backend), blocking, kernel_out);
return out; return out;
} }
......
...@@ -15,15 +15,14 @@ limitations under the License. */ ...@@ -15,15 +15,14 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/common/scalar_array.h"
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking);
std::vector<Tensor> split_impl(const Tensor& x, std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections, const ScalarArray& num_or_sections,
......
...@@ -82,13 +82,17 @@ DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) { ...@@ -82,13 +82,17 @@ DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) {
return dtype != DataType::UNDEFINED ? dtype : ParseDataType(tensor); return dtype != DataType::UNDEFINED ? dtype : ParseDataType(tensor);
} }
Backend ParseBackend(Backend backend) { return backend; } Backend ParseBackend(const Place& place) {
return phi::TransToPhiBackend(place);
}
Backend ParseBackend(const Tensor& tensor) { Backend ParseBackend(const Tensor& tensor) {
return phi::TransToPhiBackend(tensor.inner_place()); return phi::TransToPhiBackend(tensor.inner_place());
} }
Backend ParseBackendWithInputOrder(Backend backend, const Tensor& tensor) { Backend ParseBackendWithInputOrder(const Place& place, const Tensor& tensor) {
return backend != Backend::UNDEFINED ? backend : ParseBackend(tensor); return place.GetType() != phi::AllocationType::UNDEFINED
? ParseBackend(place)
: ParseBackend(tensor);
} }
DataLayout ParseLayout(DataLayout layout) { return layout; } DataLayout ParseLayout(DataLayout layout) { return layout; }
......
...@@ -154,7 +154,7 @@ DataType ParseDataType(const Tensor& tensor); ...@@ -154,7 +154,7 @@ DataType ParseDataType(const Tensor& tensor);
DataType ParseDataType(const std::vector<Tensor>& tensors); DataType ParseDataType(const std::vector<Tensor>& tensors);
DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor); DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor);
Backend ParseBackend(Backend backend); Backend ParseBackend(const Place& place);
Backend ParseBackend(const Tensor& tensor); Backend ParseBackend(const Tensor& tensor);
template <typename T, typename... Args> template <typename T, typename... Args>
Backend ParseBackend(T t, Args... args) { Backend ParseBackend(T t, Args... args) {
...@@ -163,7 +163,7 @@ Backend ParseBackend(T t, Args... args) { ...@@ -163,7 +163,7 @@ Backend ParseBackend(T t, Args... args) {
return static_cast<Backend>(64 - return static_cast<Backend>(64 -
detail::CountLeadingZeros(backend_set.bitset())); detail::CountLeadingZeros(backend_set.bitset()));
} }
Backend ParseBackendWithInputOrder(Backend backend, const Tensor& tensor); Backend ParseBackendWithInputOrder(const Place& place, const Tensor& tensor);
DataLayout ParseLayout(DataLayout layout); DataLayout ParseLayout(DataLayout layout);
DataLayout ParseLayout(const Tensor& tensor); DataLayout ParseLayout(const Tensor& tensor);
......
...@@ -27,14 +27,14 @@ namespace paddle { ...@@ -27,14 +27,14 @@ namespace paddle {
namespace experimental { namespace experimental {
// declare cast api // declare cast api
Tensor cast(const Tensor &x, DataType out_dtype); Tensor cast(const Tensor &x, DataType out_dtype);
Tensor copy_to(const Tensor &x, Backend backend, bool blocking); Tensor copy_to(const Tensor &x, Place place, bool blocking);
Tensor Tensor::cast(DataType target_type) const { Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type); return experimental::cast(*this, target_type);
} }
Tensor Tensor::copy_to(Backend backend, bool blocking) const { Tensor Tensor::copy_to(Place place, bool blocking) const {
return experimental::copy_to(*this, backend, blocking); return experimental::copy_to(*this, place, blocking);
} }
template <typename T> template <typename T>
...@@ -44,7 +44,7 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const { ...@@ -44,7 +44,7 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
"`copy_to` method without template argument instead. " "`copy_to` method without template argument instead. "
"reason: copying a Tensor to another device does not need " "reason: copying a Tensor to another device does not need "
"to specify the data type template argument."; "to specify the data type template argument.";
return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false); return copy_to(ConvertExtPlaceToInnerPlace(target_place), /*blocking=*/false);
} }
template PADDLE_API Tensor template PADDLE_API Tensor
......
...@@ -203,5 +203,10 @@ namespace paddle { ...@@ -203,5 +203,10 @@ namespace paddle {
namespace experimental { namespace experimental {
using AllocationType = phi::AllocationType; using AllocationType = phi::AllocationType;
using Place = phi::Place; using Place = phi::Place;
using CPUPlace = phi::CPUPlace;
using GPUPlace = phi::GPUPlace;
using GPUPinnedPlace = phi::GPUPinnedPlace;
using XPUPlace = phi::XPUPlace;
using NPUPlace = phi::NPUPlace;
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/complex.h" #include "paddle/phi/common/complex.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -39,10 +40,10 @@ TEST(API, data_transform_same_place) { ...@@ -39,10 +40,10 @@ TEST(API, data_transform_same_place) {
auto x = paddle::experimental::full({3, 3}, auto x = paddle::experimental::full({3, 3},
1.0, 1.0,
experimental::DataType::COMPLEX128, experimental::DataType::COMPLEX128,
experimental::Backend::CPU); experimental::CPUPlace());
auto y = paddle::experimental::full( auto y = paddle::experimental::full(
{3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::Backend::CPU); {3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::CPUPlace());
std::vector<phi::dtype::complex<double>> sum(9, 6.0); std::vector<phi::dtype::complex<double>> sum(9, 6.0);
...@@ -74,10 +75,10 @@ TEST(API, data_transform_same_place) { ...@@ -74,10 +75,10 @@ TEST(API, data_transform_same_place) {
TEST(Tensor, data_transform_diff_place) { TEST(Tensor, data_transform_diff_place) {
// 1. create tensor // 1. create tensor
auto x = paddle::experimental::full( auto x = paddle::experimental::full(
{3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::Backend::CPU); {3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::CPUPlace());
auto y = paddle::experimental::full( auto y = paddle::experimental::full(
{3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::Backend::GPU); {3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::GPUPlace());
std::vector<float> sum(9, 6.0); std::vector<float> sum(9, 6.0);
...@@ -95,7 +96,7 @@ TEST(Tensor, data_transform_diff_place) { ...@@ -95,7 +96,7 @@ TEST(Tensor, data_transform_diff_place) {
ASSERT_EQ(out.impl()->place(), ASSERT_EQ(out.impl()->place(),
phi::TransToPhiPlace(experimental::Backend::GPU)); phi::TransToPhiPlace(experimental::Backend::GPU));
auto ref_out = experimental::copy_to(out, experimental::Backend::CPU, true); auto ref_out = experimental::copy_to(out, experimental::CPUPlace(), true);
auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(ref_out.impl()); auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(ref_out.impl());
for (size_t i = 0; i < 9; i++) { for (size_t i = 0; i < 9; i++) {
......
...@@ -30,7 +30,7 @@ namespace tests { ...@@ -30,7 +30,7 @@ namespace tests {
TEST(API, scale) { TEST(API, scale) {
auto x = experimental::full( auto x = experimental::full(
{3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::Backend::CPU); {3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::CPUPlace());
const size_t cycles = 300; const size_t cycles = 300;
phi::tests::Timer timer; phi::tests::Timer timer;
......
...@@ -69,10 +69,10 @@ TEST(API, copy_to) { ...@@ -69,10 +69,10 @@ TEST(API, copy_to) {
// 2. test API // 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = paddle::experimental::copy_to(x, phi::Backend::GPU, false); auto tmp = paddle::experimental::copy_to(x, phi::GPUPlace(), false);
auto out = paddle::experimental::copy_to(tmp, phi::Backend::CPU, true); auto out = paddle::experimental::copy_to(tmp, phi::CPUPlace(), true);
#else #else
auto out = paddle::experimental::copy_to(x, phi::Backend::CPU, false); auto out = paddle::experimental::copy_to(x, phi::CPUPlace(), false);
#endif #endif
// 3. check result // 3. check result
...@@ -85,10 +85,10 @@ TEST(Tensor, copy_to) { ...@@ -85,10 +85,10 @@ TEST(Tensor, copy_to) {
// 2. test API // 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = x.copy_to(phi::Backend::GPU, false); auto tmp = x.copy_to(phi::GPUPlace(), false);
auto out = tmp.copy_to(phi::Backend::CPU, true); auto out = tmp.copy_to(phi::CPUPlace(), true);
#else #else
auto out = x.copy_to(phi::Backend::CPU, false); auto out = x.copy_to(phi::CPUPlace(), false);
#endif #endif
// 3. check result // 3. check result
......
...@@ -36,9 +36,9 @@ ...@@ -36,9 +36,9 @@
func : conj func : conj
- api : copy_to - api : copy_to
args : (Tensor x, Backend backend, bool blocking) args : (Tensor x, Place place, bool blocking)
output : Tensor output : Tensor
invoke : copy_to_impl(x, backend, blocking) invoke : copy_to_impl(x, place, blocking)
- api : divide - api : divide
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
...@@ -57,7 +57,7 @@ ...@@ -57,7 +57,7 @@
func : dot func : dot
- api : empty - api : empty
args : (ScalarArray shape, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU) args : (ScalarArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateInferMeta func : CreateInferMeta
...@@ -69,7 +69,7 @@ ...@@ -69,7 +69,7 @@
backend : place backend : place
- api : empty_like - api : empty_like
args : (Tensor x, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED) args : (Tensor x, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateLikeInferMeta func : CreateLikeInferMeta
...@@ -89,7 +89,7 @@ ...@@ -89,7 +89,7 @@
func : flatten func : flatten
- api : full - api : full
args : (ScalarArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU) args : (ScalarArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateInferMeta func : CreateInferMeta
...@@ -101,7 +101,7 @@ ...@@ -101,7 +101,7 @@
backend : place backend : place
- api : full_like - api : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED) args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateLikeInferMeta func : CreateLikeInferMeta
...@@ -138,7 +138,7 @@ ...@@ -138,7 +138,7 @@
func : multiply func : multiply
- api : ones_like - api : ones_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED) args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place={})
output : Tensor output : Tensor
invoke : full_like(x, 1, dtype, place) invoke : full_like(x, 1, dtype, place)
...@@ -218,7 +218,7 @@ ...@@ -218,7 +218,7 @@
data_type : x data_type : x
- api : zeros_like - api : zeros_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED) args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {})
output : Tensor output : Tensor
invoke : full_like(x, 0, dtype, place) invoke : full_like(x, 0, dtype, place)
......
...@@ -99,7 +99,7 @@ class BaseAPI(object): ...@@ -99,7 +99,7 @@ class BaseAPI(object):
'double': 'double', 'double': 'double',
'bool': 'bool', 'bool': 'bool',
'str': 'const std::string&', 'str': 'const std::string&',
'Backend': 'Backend', 'Place': 'Place',
'DataLayout': 'DataLayout', 'DataLayout': 'DataLayout',
'DataType': 'DataType', 'DataType': 'DataType',
'int64[]': 'const std::vector<int64_t>&', 'int64[]': 'const std::vector<int64_t>&',
...@@ -118,7 +118,7 @@ class BaseAPI(object): ...@@ -118,7 +118,7 @@ class BaseAPI(object):
'float': 'paddle::optional<float>', 'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>', 'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>', 'bool': 'paddle::optional<bool>',
'Backend': 'paddle::optional<Backend>', 'Place': 'paddle::optional<Place>',
'DataLayout': 'paddle::optional<DataLayout>', 'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>', 'DataType': 'paddle::optional<DataType>',
'int64[]': 'paddle::optional<std::vector<int64_t>>', 'int64[]': 'paddle::optional<std::vector<int64_t>>',
...@@ -327,9 +327,9 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -327,9 +327,9 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
attr_layout_count = 0 attr_layout_count = 0
attr_data_type_count = 0 attr_data_type_count = 0
for attr_name in attrs['names']: for attr_name in attrs['names']:
if attrs['attr_info'][attr_name][0] == 'Backend': if attrs['attr_info'][attr_name][0] == 'Place':
assert kernel['backend'] is not None, \ assert kernel['backend'] is not None, \
f"{api} api: When there is a parameter with 'Backend' type in attributes, you must set backend of kernel manually." f"{api} api: When there is a parameter with 'Place' type in attributes, you must set backend of kernel manually."
attr_backend_count = attr_backend_count + 1 attr_backend_count = attr_backend_count + 1
if attrs['attr_info'][attr_name][0] == 'DataLayout': if attrs['attr_info'][attr_name][0] == 'DataLayout':
assert kernel['layout'] is not None, \ assert kernel['layout'] is not None, \
...@@ -348,8 +348,8 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -348,8 +348,8 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
assert len( assert len(
vars_list vars_list
) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}." ) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Backend'), \ assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Backend type." f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
kernel_select_code = kernel_select_code + f""" kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()}); kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册