Commit d9b202e7 authored by minqiyang

Move tensor copy src_ptr and dst_ptr check to TensorCopy function

test=develop
Parent f4084882
@@ -114,6 +114,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
   auto dst_ptr = dst->mutable_data(dst_place, src.type());
   auto size = src.numel() * SizeOfType(src.type());
   if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
+    if (src_ptr == dst_ptr) {
+      VLOG(3) << "Skip copy the same data from " << src.place() << " to "
+              << dst_place;
+      return;
+    }
     memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
                  boost::get<platform::CPUPlace>(src_place), src_ptr, size);
   }
@@ -132,6 +137,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
              platform::is_gpu_place(dst_place)) {
     auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
     auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
+    if (src_ptr == dst_ptr &&
+        src_gpu_place.GetDeviceId() == dst_gpu_place.GetDeviceId()) {
+      VLOG(3) << "Skip copy the same data from " << src.place() << " to "
+              << dst_place;
+      return;
+    }
     memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);
   }
 #endif
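
The guard added above makes TensorCopySync a no-op when the source and destination already alias the same buffer (and, for GPU-to-GPU copies, live on the same device). A minimal standalone sketch of the pattern, in plain C++ with illustrative names rather than Paddle's actual API:

#include <cstddef>
#include <cstring>
#include <iostream>

// Stand-in for the guard TensorCopySync gains in this commit: when src and
// dst point at the same memory the copy is skipped, which both avoids wasted
// work and sidesteps memcpy's undefined behavior on overlapping (here:
// identical) regions.
void CopySync(const void* src_ptr, void* dst_ptr, std::size_t size) {
  if (src_ptr == dst_ptr) {
    std::cout << "Skip copy of aliased buffer\n";  // stands in for VLOG(3)
    return;
  }
  std::memcpy(dst_ptr, src_ptr, size);
}

int main() {
  int src[4] = {1, 2, 3, 4};
  int dst[4] = {};
  CopySync(src, dst, sizeof(src));  // distinct buffers: real copy
  CopySync(src, src, sizeof(src));  // same buffer: skipped
  return 0;
}

The hunks below are the caller-side cleanup: the reshape kernels previously performed this pointer comparison themselves.
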
@@ -195,7 +195,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename T>
 class ReshapeKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -228,15 +227,12 @@ class ReshapeKernel {
                        "sequence_reshape op.");
     }
 
-    if (in->data<T>() !=
-        reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
-      framework::TensorCopySync(*in, ctx.GetPlace(), out);
-    }
+    out->mutable_data(ctx.GetPlace(), in->type());
+    framework::TensorCopySync(*in, ctx.GetPlace(), out);
     out->Resize(out_dims);
   }
 };
 
-template <typename T>
 class ReshapeGradKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -244,9 +240,8 @@ class ReshapeGradKernel {
     auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto in_dims = d_x->dims();
 
-    if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
-      framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
-    }
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
     d_x->Resize(in_dims);
   }
 };
@@ -341,46 +336,38 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
-                               double, ops::ReshapeKernel<double>, int,
-                               ops::ReshapeKernel<int>, int64_t,
-                               ops::ReshapeKernel<int64_t>);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
-                               ops::ReshapeGradKernel<float>, double,
-                               ops::ReshapeGradKernel<double>, int,
-                               ops::ReshapeGradKernel<int>, int64_t,
-                               ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
+                               ops::ReshapeKernel, int, ops::ReshapeKernel,
+                               int64_t, ops::ReshapeKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
+                               double, ops::ReshapeGradKernel, int,
+                               ops::ReshapeGradKernel, int64_t,
+                               ops::ReshapeGradKernel);
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
-                               double, ops::ReshapeKernel<double>, int,
-                               ops::ReshapeKernel<int>, int64_t,
-                               ops::ReshapeKernel<int64_t>);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
-                               ops::ReshapeGradKernel<float>, double,
-                               ops::ReshapeGradKernel<double>, int,
-                               ops::ReshapeGradKernel<int>, int64_t,
-                               ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
+                               ops::ReshapeKernel, int, ops::ReshapeKernel,
+                               int64_t, ops::ReshapeKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
+                               double, ops::ReshapeGradKernel, int,
+                               ops::ReshapeGradKernel, int64_t,
+                               ops::ReshapeGradKernel);
 #ifdef PADDLE_WITH_CUDA
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
-                                double, ops::ReshapeKernel<double>, int,
-                                ops::ReshapeKernel<int>, int64_t,
-                                ops::ReshapeKernel<int64_t>);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
-                                ops::ReshapeGradKernel<float>, double,
-                                ops::ReshapeGradKernel<double>, int,
-                                ops::ReshapeGradKernel<int>, int64_t,
-                                ops::ReshapeGradKernel<int64_t>);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
-                                double, ops::ReshapeKernel<double>, int,
-                                ops::ReshapeKernel<int>, int64_t,
-                                ops::ReshapeKernel<int64_t>);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
-                                ops::ReshapeGradKernel<float>, double,
-                                ops::ReshapeGradKernel<double>, int,
-                                ops::ReshapeGradKernel<int>, int64_t,
-                                ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
+                                ops::ReshapeKernel, int, ops::ReshapeKernel,
+                                int64_t, ops::ReshapeKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
+                                double, ops::ReshapeGradKernel, int,
+                                ops::ReshapeGradKernel, int64_t,
+                                ops::ReshapeGradKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
+                                ops::ReshapeKernel, int, ops::ReshapeKernel,
+                                int64_t, ops::ReshapeKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
+                                double, ops::ReshapeGradKernel, int,
+                                ops::ReshapeGradKernel, int64_t,
+                                ops::ReshapeGradKernel);
 #endif
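
Because the aliasing check now lives inside TensorCopySync, the kernels no longer need the element type T to compare raw data<T>() pointers, so template <typename T> disappears and the per-type registrations (ops::ReshapeKernel<float>, ops::ReshapeKernel<double>, ...) collapse into one type-agnostic functor registered for every dtype. A minimal sketch of that simplification, using simplified stand-in types rather than Paddle's real Tensor and ExecutionContext:

#include <vector>

// Illustrative stand-ins only; not Paddle's API.
struct Buffer {
  std::vector<unsigned char> bytes;
  const void* data() const { return bytes.data(); }
  void* data() { return bytes.data(); }
};

// The copy routine owns the aliasing check (as TensorCopySync does after
// this commit), so callers copy unconditionally.
void CopySync(const Buffer& src, Buffer* dst) {
  if (src.data() == dst->data()) return;  // self-copy: skip
  dst->bytes = src.bytes;
}

// Before, the kernel had to be template <typename T> just to call
// in.data<T>() for the pointer comparison; now a single untyped functor
// serves float, double, int, and int64_t alike.
struct ReshapeLikeKernel {
  void operator()(const Buffer& in, Buffer* out) const {
    CopySync(in, out);  // no caller-side pointer check needed
  }
};
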