未验证 提交 e1a792fe 编写于 作者: H HongyuJia 提交者: GitHub

[Bug Fix] Fix NLP-Bert model performance loss (#50333)

* fix NLP-Bert model performance loss

* fix windows compile error
上级 ffbda80c
......@@ -20,7 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
......@@ -131,10 +133,31 @@ inline bool backends_are_same_class(const phi::Backend& l,
return phi::TransToPhiPlace(l) == phi::TransToPhiPlace(r);
}
inline bool NeedTransform(const phi::KernelKey& l, const phi::KernelKey& r) {
return !backends_are_same_class(l.backend(), r.backend()) ||
NeedTransformDataType(l, r) ||
NeedTransformLayout(l.layout(), r.layout());
// Decides whether `tensor` has to be moved to another backend before the
// expected kernel can consume it.
//
// type_for_var_backend: backend of the kernel key computed for the variable.
// expected_backend:     backend of the kernel that is going to run.
// tensor:               the input tensor; only its place (and its name, for
//                       logging) is inspected.
//
// Returns true when a backend transform is required.
inline bool NeedTransformBackend(const phi::Backend& type_for_var_backend,
                                 const phi::Backend& expected_backend,
                                 const phi::DenseTensor& tensor) {
  // Ordinary case: fall back to the backend-class comparison. The operands
  // are evaluated left to right with short-circuiting, so tensor.place() is
  // only queried when the variable backend is not ALL_BACKEND.
  if (type_for_var_backend == phi::Backend::ALL_BACKEND ||
      !paddle::platform::is_cuda_pinned_place(tensor.place()) ||
      expected_backend == phi::Backend::CPU) {
    return !backends_are_same_class(type_for_var_backend, expected_backend);
  }

  // NOTE(jiahongyu): KernelKey does not hold place information, so we need to
  // explicitly transform CUDAPinnedPlace->CUDAPlace
  VLOG(3) << "Transform Variable " << tensor.name() << " from "
          << tensor.place() << " to "
          << phi::TransToPhiPlace(expected_backend);
  return true;
}
// Reports whether `tensor` needs any transformation before it can be fed to
// the kernel described by `expected_kernel_key`.
//
// The three checks run in a fixed order (backend, then data type, then
// layout) and the first one that asks for a transform ends the evaluation.
inline bool NeedTransform(const phi::KernelKey& kernel_type_for_var,
                          const phi::KernelKey& expected_kernel_key,
                          const phi::DenseTensor& tensor) {
  // Backend check; this is the only check that looks at the tensor itself.
  if (NeedTransformBackend(kernel_type_for_var.backend(),
                           expected_kernel_key.backend(),
                           tensor)) {
    return true;
  }

  // Data type check on the two kernel keys.
  if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) {
    return true;
  }

  // Layout check decides the result when neither check above fired.
  return NeedTransformLayout(kernel_type_for_var.layout(),
                             expected_kernel_key.layout());
}
} // namespace framework
......
......@@ -87,8 +87,8 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
if (tensor && tensor->IsInitialized() && (tensor->memory_size() != 0)) {
auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key);
if (!framework::NeedTransform(kernel_type_for_var,
expected_kernel_key)) {
if (!framework::NeedTransform(
kernel_type_for_var, expected_kernel_key, *tensor)) {
continue;
} else {
VLOG(3) << "Transform Variable " << GetNameFromVar(template_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册