未验证 提交 d2aaa751 编写于 作者: Z zyfncg 提交者: GitHub

Optimize performance of C++ Api (part2) (#40729)

* optimize performance of  C++ API

* remove stop_data_transform flag temparily
上级 72a2bfe2
...@@ -35,7 +35,7 @@ class BackendSet final { ...@@ -35,7 +35,7 @@ class BackendSet final {
: bitset_(b == Backend::UNDEFINED ? 0 : 1ULL << (static_cast<uint8_t>(b) - : bitset_(b == Backend::UNDEFINED ? 0 : 1ULL << (static_cast<uint8_t>(b) -
1)) {} 1)) {}
uint64_t bitset() const { return bitset_; } inline uint64_t bitset() const { return bitset_; }
bool inline Has(Backend b) const { bool inline Has(Backend b) const {
PD_CHECK(b != Backend::UNDEFINED, "Backend argument can't be UNDEFINED."); PD_CHECK(b != Backend::UNDEFINED, "Backend argument can't be UNDEFINED.");
......
...@@ -39,7 +39,7 @@ inline bool NeedTransformPlace(const paddle::platform::Place& input, ...@@ -39,7 +39,7 @@ inline bool NeedTransformPlace(const paddle::platform::Place& input,
const TransformFlag& transform_flag) { const TransformFlag& transform_flag) {
bool ret = transform_flag.need_trans_backend() && bool ret = transform_flag.need_trans_backend() &&
target != Backend::ALL_BACKEND && target != Backend::ALL_BACKEND &&
!platform::is_same_place(input, phi::TransToPhiPlace(target)); phi::TransToPhiBackend(input) != target;
return ret; return ret;
} }
...@@ -180,21 +180,20 @@ std::shared_ptr<phi::DenseTensor> PrepareData( ...@@ -180,21 +180,20 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
const phi::TensorArgDef& target_args_def, const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) { const TransformFlag& transform_flag) {
const auto& tensor_in = input.impl(); const auto& tensor_in = input.impl();
VLOG(6) << tensor_in->dtype() << "\t" << target_args_def.dtype; phi::DenseTensor& dense_tensor =
if (!transform_flag.NeedTransform() || !tensor_in->initialized() || *static_cast<phi::DenseTensor*>(tensor_in.get());
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace( (!NeedTransformPlace(
tensor_in->place(), target_args_def.backend, transform_flag) && dense_tensor.place(), target_args_def.backend, transform_flag) &&
!NeedTransformDataType( !NeedTransformDataType(
tensor_in->dtype(), target_args_def.dtype, transform_flag) && dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
!NeedTransformLayout( !NeedTransformLayout(
tensor_in->layout(), target_args_def.layout, transform_flag))) { dense_tensor.layout(), target_args_def.layout, transform_flag))) {
return std::static_pointer_cast<phi::DenseTensor>(tensor_in); return std::static_pointer_cast<phi::DenseTensor>(tensor_in);
} }
phi::DenseTensor out = phi::DenseTensor out =
TransformData(*(static_cast<phi::DenseTensor*>(tensor_in.get())), TransformData(dense_tensor, target_args_def, transform_flag);
target_args_def,
transform_flag);
return std::make_shared<phi::DenseTensor>(std::move(out)); return std::make_shared<phi::DenseTensor>(std::move(out));
} }
......
...@@ -30,7 +30,7 @@ class DataTypeSet final { ...@@ -30,7 +30,7 @@ class DataTypeSet final {
? 0 ? 0
: 1ULL << (static_cast<uint8_t>(dtype) - 1)) {} : 1ULL << (static_cast<uint8_t>(dtype) - 1)) {}
uint64_t bitset() const { return bitset_; } inline uint64_t bitset() const { return bitset_; }
bool inline Has(DataType dtype) const { bool inline Has(DataType dtype) const {
PD_CHECK(dtype != DataType::UNDEFINED, PD_CHECK(dtype != DataType::UNDEFINED,
......
...@@ -16,13 +16,16 @@ limitations under the License. */ ...@@ -16,13 +16,16 @@ limitations under the License. */
#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#ifdef _MSC_VER
#include <intrin.h>
#endif
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
namespace detail { namespace detail {
BackendSet GetTensorBackendSet(const Tensor& t) { BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
BackendSet backend_set(phi::TransToPhiBackend(t.inner_place())); BackendSet backend_set(phi::TransToPhiBackend(t.place()));
switch (t.layout()) { switch (t.layout()) {
case DataLayout::MKLDNN: case DataLayout::MKLDNN:
backend_set = backend_set | BackendSet(Backend::MKLDNN); backend_set = backend_set | BackendSet(Backend::MKLDNN);
...@@ -35,6 +38,11 @@ BackendSet GetTensorBackendSet(const Tensor& t) { ...@@ -35,6 +38,11 @@ BackendSet GetTensorBackendSet(const Tensor& t) {
} }
std::size_t CountLeadingZeros(uint64_t val) { std::size_t CountLeadingZeros(uint64_t val) {
#if defined(__clang__) || defined(__GNUC__)
return __builtin_clzl(val);
#elif defined(_MSC_VER)
return __lzcnt64(val);
#else
if (val == 0) { if (val == 0) {
return 64; return 64;
} }
...@@ -48,6 +56,7 @@ std::size_t CountLeadingZeros(uint64_t val) { ...@@ -48,6 +56,7 @@ std::size_t CountLeadingZeros(uint64_t val) {
} }
} }
return zero_bits; return zero_bits;
#endif
} }
} // namespace detail } // namespace detail
......
...@@ -33,7 +33,7 @@ namespace paddle { ...@@ -33,7 +33,7 @@ namespace paddle {
namespace experimental { namespace experimental {
namespace detail { namespace detail {
BackendSet GetTensorBackendSet(const Tensor& t); BackendSet GetTensorBackendSet(const phi::TensorBase& t);
std::size_t CountLeadingZeros(uint64_t val); std::size_t CountLeadingZeros(uint64_t val);
} // namespace detail } // namespace detail
...@@ -93,11 +93,13 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -93,11 +93,13 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
// TODO(chenweihang): deal with multiple diff input Tensors // TODO(chenweihang): deal with multiple diff input Tensors
// TODO(chenweihang): add global device guard method to set backend // TODO(chenweihang): add global device guard method to set backend
void operator()(const Tensor& x) { void operator()(const Tensor& x) {
key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x); const phi::TensorBase& tensor = *x.impl();
// TODO(chenweihang): selecte multi layout and dtype key_set.backend_set =
key_set.layout = x.layout(); key_set.backend_set | detail::GetTensorBackendSet(tensor);
key_set.dtype = x.type(); // TODO(chenweihang): select multi layout and dtype
dtype_set = dtype_set | DataTypeSet(x.dtype()); key_set.layout = tensor.layout();
key_set.dtype = tensor.dtype();
dtype_set = dtype_set | DataTypeSet(key_set.dtype);
auto promote_result = PromoteTypes(dtype_set); auto promote_result = PromoteTypes(dtype_set);
if (promote_result != DataType::UNDEFINED) { if (promote_result != DataType::UNDEFINED) {
key_set.dtype = promote_result; key_set.dtype = promote_result;
...@@ -105,11 +107,12 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -105,11 +107,12 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
} }
void operator()(const std::vector<Tensor>& x) { void operator()(const std::vector<Tensor>& x) {
const phi::TensorBase& tensor = *x.at(0).impl();
key_set.backend_set = key_set.backend_set =
key_set.backend_set | detail::GetTensorBackendSet(x[0]); key_set.backend_set | detail::GetTensorBackendSet(tensor);
// TODO(chenweihang): selecte multi layout and dtype // TODO(chenweihang): select multi layout and dtype
key_set.layout = x[0].layout(); key_set.layout = tensor.layout();
key_set.dtype = x[0].type(); key_set.dtype = tensor.dtype();
} }
// skip other type args, these args don't used in kernel selection // skip other type args, these args don't used in kernel selection
......
...@@ -26,13 +26,14 @@ limitations under the License. */ ...@@ -26,13 +26,14 @@ limitations under the License. */
namespace phi { namespace phi {
Backend TransToPhiBackend(const phi::Place& place) { Backend TransToPhiBackend(const phi::Place& place) {
if (place.GetType() == phi::AllocationType::CPU) { auto allocation_type = place.GetType();
if (allocation_type == phi::AllocationType::CPU) {
return Backend::CPU; return Backend::CPU;
} else if (place.GetType() == phi::AllocationType::GPU) { } else if (allocation_type == phi::AllocationType::GPU) {
return Backend::GPU; return Backend::GPU;
} else if (place.GetType() == phi::AllocationType::XPU) { } else if (allocation_type == phi::AllocationType::XPU) {
return Backend::XPU; return Backend::XPU;
} else if (place.GetType() == phi::AllocationType::CUSTOM) { } else if (allocation_type == phi::AllocationType::CUSTOM) {
return static_cast<Backend>( return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) + static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType())); GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册