提交 25ecd206 编写于 作者: X xuwei06

Use CopyFromVector for assign_value_op

上级 237385cf
...@@ -116,8 +116,8 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place, ...@@ -116,8 +116,8 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place,
* @param[in] src The external tensor. * @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources. * @param[in] ctx The device context contains device resources.
* *
* * @note CopyFromVector assumes that the tensor has been resized * * @note CopyFromVector will resize dst to an 1D tensor with the same
* before invoking. * size as src.
*/ */
template <typename T> template <typename T>
inline void CopyFromVector(const std::vector<T>& src, inline void CopyFromVector(const std::vector<T>& src,
......
...@@ -63,20 +63,11 @@ $$Out = values$$ ...@@ -63,20 +63,11 @@ $$Out = values$$
} }
}; };
template <typename T>
class AssignValueCPUKernel : public AssignValueKernel<T> {
protected:
virtual void Copy(void *dst, const void *src, size_t size,
const framework::ExecutionContext &ctx) const {
std::memcpy(dst, src, size);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker); REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker);
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueCPUKernel<int>, REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueCPUKernel<float>) ops::AssignValueKernel<float>);
...@@ -14,23 +14,6 @@ limitations under the License. */ ...@@ -14,23 +14,6 @@ limitations under the License. */
#include "paddle/operators/assign_value_op.h" #include "paddle/operators/assign_value_op.h"
namespace paddle {
namespace operators {
template <typename T>
class AssignValueGPUKernel : public AssignValueKernel<T> {
protected:
virtual void Copy(void* dst, const void* src, size_t size,
const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
paddle::platform::GpuMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice,
dev_ctx.stream());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueGPUKernel<int>, REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueGPUKernel<float>); ops::AssignValueKernel<float>);
...@@ -27,8 +27,6 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -27,8 +27,6 @@ class AssignValueKernel : public framework::OpKernel<T> {
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
auto shape = ctx.Attr<std::vector<int>>("shape"); auto shape = ctx.Attr<std::vector<int>>("shape");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
out->Resize(framework::make_ddim(shape));
auto* dst = out->mutable_data<T>(ctx.GetPlace());
int dtype = ctx.Attr<int>("dtype"); int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr; const char* value_name = nullptr;
switch (dtype) { switch (dtype) {
...@@ -43,12 +41,9 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -43,12 +41,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
break; break;
} }
auto values = ctx.Attr<std::vector<T>>(value_name); auto values = ctx.Attr<std::vector<T>>(value_name);
Copy(dst, values.data(), sizeof(T) * values.size(), ctx); framework::CopyFromVector(values, ctx.device_context(), out);
out->Resize(framework::make_ddim(shape));
} }
protected:
virtual void Copy(void* dst, const void* src, size_t size,
const framework::ExecutionContext& ctx) const = 0;
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册