From 25ecd2061ae13656653fbd89c2216375e0cd9e55 Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Thu, 11 Jan 2018 12:52:54 -0800 Subject: [PATCH] Use CopyFromVector for assign_value_op --- paddle/framework/tensor_util.h | 4 ++-- paddle/operators/assign_value_op.cc | 13 ++----------- paddle/operators/assign_value_op.cu.cc | 21 ++------------------- paddle/operators/assign_value_op.h | 9 ++------- 4 files changed, 8 insertions(+), 39 deletions(-) diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index f541d2ba6..091b63bf0 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -116,8 +116,8 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place, * @param[in] src The external tensor. * @param[in] ctx The device context contains device resources. * - * * @note CopyFromVector assumes that the tensor has been resized - * before invoking. + * * @note CopyFromVector will resize dst to an 1D tensor with the same + * size as src. */ template inline void CopyFromVector(const std::vector& src, diff --git a/paddle/operators/assign_value_op.cc b/paddle/operators/assign_value_op.cc index 7f6d351f5..d5671c118 100644 --- a/paddle/operators/assign_value_op.cc +++ b/paddle/operators/assign_value_op.cc @@ -63,20 +63,11 @@ $$Out = values$$ } }; -template -class AssignValueCPUKernel : public AssignValueKernel { - protected: - virtual void Copy(void *dst, const void *src, size_t size, - const framework::ExecutionContext &ctx) const { - std::memcpy(dst, src, size); - } -}; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker); -REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueCPUKernel, - ops::AssignValueCPUKernel) +REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel, + ops::AssignValueKernel); diff --git a/paddle/operators/assign_value_op.cu.cc b/paddle/operators/assign_value_op.cu.cc index 8afb03203..b17e20150 100644 --- a/paddle/operators/assign_value_op.cu.cc +++ b/paddle/operators/assign_value_op.cu.cc @@ -14,23 +14,6 @@ limitations under the License. */ #include "paddle/operators/assign_value_op.h" -namespace paddle { -namespace operators { - -template -class AssignValueGPUKernel : public AssignValueKernel { - protected: - virtual void Copy(void* dst, const void* src, size_t size, - const framework::ExecutionContext& ctx) const { - auto& dev_ctx = ctx.template device_context(); - paddle::platform::GpuMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, - dev_ctx.stream()); - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueGPUKernel, - ops::AssignValueGPUKernel); +REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel, + ops::AssignValueKernel); diff --git a/paddle/operators/assign_value_op.h b/paddle/operators/assign_value_op.h index bdb5bce27..db2e43077 100644 --- a/paddle/operators/assign_value_op.h +++ b/paddle/operators/assign_value_op.h @@ -27,8 +27,6 @@ class AssignValueKernel : public framework::OpKernel { virtual void Compute(const framework::ExecutionContext& ctx) const { auto shape = ctx.Attr>("shape"); auto* out = ctx.Output("Out"); - out->Resize(framework::make_ddim(shape)); - auto* dst = out->mutable_data(ctx.GetPlace()); int dtype = ctx.Attr("dtype"); const char* value_name = nullptr; switch (dtype) { @@ -43,12 +41,9 @@ class AssignValueKernel : public framework::OpKernel { break; } auto values = ctx.Attr>(value_name); - Copy(dst, values.data(), sizeof(T) * values.size(), ctx); + framework::CopyFromVector(values, ctx.device_context(), out); + out->Resize(framework::make_ddim(shape)); } - - protected: - virtual void Copy(void* dst, const void* src, size_t size, - const framework::ExecutionContext& ctx) const = 0; }; } // namespace operators -- GitLab