未验证 提交 056edf39 编写于 作者: A Adam 提交者: GitHub

Change ShareDataWith() to TensorCopy() in conv_mkldnn (#22695)

上级 432a4b27
......@@ -48,6 +48,19 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
}
auto size = src.numel() * SizeOfType(src.type());
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
PADDLE_ENFORCE_EQ(
src.memory_size(), dst->memory_size(),
platform::errors::InvalidArgument(
"When copying tensor with MKL-DNN data layout, "
"memory size of source tensor should be the same as memory size of "
"destination tensor. "
"But received src.memory_size = %d, dst.memory_size = %d.",
src.memory_size(), dst->memory_size()));
size = src.memory_size();
}
#endif
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
......
......@@ -316,8 +316,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else {
output->ShareDataWith(*residual_param);
auto output_data = output->mutable_data<T>(ctx.GetPlace());
auto output_data = output->mutable_data<T>(
ctx.GetPlace(), residual_param->memory_size());
framework::TensorCopy(*residual_param, residual_param->place(), output);
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
}
......@@ -610,7 +611,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
} else {
output->ShareDataWith(*residual_param);
framework::TensorCopy(*residual_param, residual_param->place(),
output);
dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
}
need_s8_to_u8 =
......@@ -681,7 +683,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
output->ShareDataWith(*residual_param);
framework::TensorCopy(*residual_param, residual_param->place(), output);
need_s8_to_u8 =
(platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
unsigned_output;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册