提交 25ecd206 编写于 作者: X xuwei06

Use CopyFromVector for assign_value_op

上级 237385cf
......@@ -116,8 +116,8 @@ inline void Copy(const Tensor& src, const platform::Place& dst_place,
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
* * @note CopyFromVector will resize dst to an 1D tensor with the same
* size as src.
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src,
......
......@@ -63,20 +63,11 @@ $$Out = values$$
}
};
template <typename T>
class AssignValueCPUKernel : public AssignValueKernel<T> {
protected:
virtual void Copy(void *dst, const void *src, size_t size,
const framework::ExecutionContext &ctx) const {
std::memcpy(dst, src, size);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker);
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueCPUKernel<int>,
ops::AssignValueCPUKernel<float>)
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>);
......@@ -14,23 +14,6 @@ limitations under the License. */
#include "paddle/operators/assign_value_op.h"
namespace paddle {
namespace operators {
template <typename T>
class AssignValueGPUKernel : public AssignValueKernel<T> {
protected:
virtual void Copy(void* dst, const void* src, size_t size,
const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
paddle::platform::GpuMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice,
dev_ctx.stream());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueGPUKernel<int>,
ops::AssignValueGPUKernel<float>);
REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>);
......@@ -27,8 +27,6 @@ class AssignValueKernel : public framework::OpKernel<T> {
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto shape = ctx.Attr<std::vector<int>>("shape");
auto* out = ctx.Output<framework::Tensor>("Out");
out->Resize(framework::make_ddim(shape));
auto* dst = out->mutable_data<T>(ctx.GetPlace());
int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr;
switch (dtype) {
......@@ -43,12 +41,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
break;
}
auto values = ctx.Attr<std::vector<T>>(value_name);
Copy(dst, values.data(), sizeof(T) * values.size(), ctx);
framework::CopyFromVector(values, ctx.device_context(), out);
out->Resize(framework::make_ddim(shape));
}
protected:
virtual void Copy(void* dst, const void* src, size_t size,
const framework::ExecutionContext& ctx) const = 0;
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册