Unverified commit d666c7df, authored by Paulina Gacek, committed by GitHub

[PHI] OneDNN version of Copy (#48539)

* OneDNN version of Copy, transpose kernels adjusted

* style fixes in transpose_grad

* redundant headers deleted
Parent 69e695b7
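The new OneDNNContext instantiation lets oneDNN PHI kernels call the generic phi::Copy directly instead of going through paddle::framework::TensorCopy, and the CPU branch of Copy now carries the source layout over to the destination tensor. Below is a minimal sketch of the calling pattern the adjusted transpose kernels follow; the kernel name and the include paths are illustrative assumptions, not part of this commit.

// Sketch only: how a oneDNN PHI kernel can reuse the generic phi::Copy once
// the OneDNNContext instantiation exists. Include paths are assumptions based
// on the phi directory layout, not copied from this commit.
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"  // assumed to declare phi::Copy

namespace phi {

// Hypothetical helper that forwards a tensor unchanged, mirroring the
// axis.size() == 1 fast path of the transpose kernels in this diff.
template <typename Context>
void ForwardUnchanged(const Context& dev_ctx,
                      const DenseTensor& x,
                      DenseTensor* out) {
  // blocking = false, as in the adjusted kernels; with PADDLE_WITH_MKLDNN the
  // CPU branch of Copy also propagates src.layout() to the destination.
  Copy<Context>(dev_ctx, x, x.place(), /*blocking=*/false, out);
  // The oneDNN memory descriptor is still forwarded explicitly by the caller.
  out->set_mem_desc(x.mem_desc());
}

}  // namespace phi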
@@ -56,6 +56,9 @@ void Copy(const Context& dev_ctx,
   void* dst_ptr = nullptr;
   if (paddle::platform::is_cpu_place(dst_place)) {
     dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
+#ifdef PADDLE_WITH_MKLDNN
+    dst->set_layout(src.layout());
+#endif
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   } else if (paddle::platform::is_gpu_place(dst_place) ||
              paddle::platform::is_cuda_pinned_place(dst_place)) {
@@ -81,7 +84,7 @@ void Copy(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(
         dst->place(),
         dst_place,
-        phi::errors::Unavailable(
+        errors::Unavailable(
             "The Dst Tensor's place and dst_place do not match, Tensor's place "
             "place is %s, dst_place is %s.",
             dst->place(),
@@ -112,13 +115,13 @@ void Copy(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(
         paddle::platform::is_gpu_place(ctx_place),
         true,
-        phi::errors::PreconditionNotMet(
+        errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
     auto ctx_gpu_place = ctx_place;
     PADDLE_ENFORCE_EQ(src_gpu_place,
                       ctx_gpu_place,
-                      phi::errors::Unavailable(
+                      errors::Unavailable(
                           "Source place and context place do not match, source "
                           "place is %s, context place is %s.",
                           src_gpu_place,
@@ -137,14 +140,14 @@ void Copy(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(
         paddle::platform::is_gpu_place(ctx_place),
         true,
-        phi::errors::PreconditionNotMet(
+        errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
     auto ctx_gpu_place = ctx_place;
-    PADDLE_ENFORCE_EQ(dst_gpu_place,
+    PADDLE_ENFORCE_EQ(
+        dst_gpu_place,
         ctx_gpu_place,
-        phi::errors::Unavailable(
-            "Destination place and context place do not match, "
+        errors::Unavailable("Destination place and context place do not match, "
             "destination place is %s, context place is %s.",
             dst_gpu_place,
             ctx_gpu_place));
@@ -161,7 +164,7 @@ void Copy(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(
         paddle::platform::is_gpu_place(ctx_place),
         true,
-        phi::errors::PreconditionNotMet(
+        errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
     auto stream =
@@ -184,7 +187,7 @@ void Copy(const Context& dev_ctx,
       paddle::memory::Copy(
           dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
     } else {
-      PADDLE_THROW(phi::errors::Unavailable(
+      PADDLE_THROW(errors::Unavailable(
           "Context place dose not match the source and destination place."));
     }
   }
@@ -196,13 +199,13 @@ void Copy(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(
         paddle::platform::is_gpu_place(ctx_place),
         true,
-        phi::errors::PreconditionNotMet(
+        errors::PreconditionNotMet(
            "Context place error, excepted GPUPlace, but actually %s.",
            ctx_place));
     auto ctx_gpu_place = ctx_place;
     PADDLE_ENFORCE_EQ(src_gpu_place,
                       ctx_gpu_place,
-                      phi::errors::Unavailable(
+                      errors::Unavailable(
                           "Source place and context place do not match, source "
                           "place is %s, context place is %s.",
                           src_gpu_place,
@@ -259,7 +262,7 @@ void Copy(const Context& dev_ctx,
     paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
 #endif
   } else {
-    PADDLE_THROW(phi::errors::Unimplemented(
+    PADDLE_THROW(errors::Unimplemented(
         "Copy from %s to %s is not supported.", src_place, dst_place));
   }
 }
@@ -411,4 +414,12 @@ template void Copy(const CustomContext& dev_ctx,
                    bool blocking,
                    DenseTensor* dst);
 #endif
+
+#ifdef PADDLE_WITH_MKLDNN
+template void Copy(const OneDNNContext& dev_ctx,
+                   const DenseTensor& src,
+                   Place dst_place,
+                   bool blocking,
+                   DenseTensor* dst);
+#endif
 }  // namespace phi
@@ -13,8 +13,6 @@
 // limitations under the License.

 #include "paddle/phi/kernels/transpose_grad_kernel.h"

-#include "paddle/fluid/framework/tensor_util.h"
-
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -24,16 +22,16 @@ void TransposeGradKernel(const Context& dev_ctx,
                          const DenseTensor& out_grad,
                          const std::vector<int>& axis,
                          DenseTensor* x_grad) {
-  PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU,
+  PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType() == AllocationType::CPU,
                     true,
                     errors::PreconditionNotMet(
-                        "Operator DNNL TransposeGrad must use CPUPlace"));
+                        "oneDNN TransposeGrad kernel must use CPUPlace"));
   if (!x_grad) return;

   const auto& onednn_engine = dev_ctx.GetEngine();

   if (axis.size() == 1) {
-    paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad);
+    Copy<Context>(dev_ctx, out_grad, out_grad.place(), false, x_grad);
     x_grad->set_mem_desc(out_grad.mem_desc());
     return;
   }
...
@@ -13,7 +13,6 @@
 // limitations under the License.

 #include "paddle/phi/kernels/transpose_kernel.h"

-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -80,7 +79,7 @@ void TransposeKernel(const Context& dev_ctx,
       dev_ctx, const_cast<DenseTensor*>(&x), x.mem_desc());

   if (axis.size() == 1) {
-    paddle::framework::TensorCopy(x, x.place(), out);
+    Copy<Context>(dev_ctx, x, x.place(), false, out);
     out->set_mem_desc(x.mem_desc());
     return;
   }
...