未验证 提交 01a96463 编写于 作者: Z Zhang Ting 提交者: GitHub

optimize assign op to avoid copy data from GPU to GPU (#21181)

* optimize assign op to avoid copy data from GPU to GPU, test=develop

* modified GetkernelTypeForVar and just avoid device transform, test=develop
上级 c91cb6c5
...@@ -41,6 +41,14 @@ class AssignOp : public framework::OperatorWithKernel { ...@@ -41,6 +41,14 @@ class AssignOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
......
...@@ -47,7 +47,7 @@ class AssignFunctor { ...@@ -47,7 +47,7 @@ class AssignFunctor {
out_rows.set_height(rows.height()); out_rows.set_height(rows.height());
auto &t = rows.value(); auto &t = rows.value();
auto *m = out_rows.mutable_value(); auto *m = out_rows.mutable_value();
framework::TensorCopy(t, t.place(), dev_ctx_, m); framework::TensorCopy(t, dev_ctx_.GetPlace(), dev_ctx_, m);
} }
template <typename T> template <typename T>
...@@ -60,7 +60,7 @@ class AssignFunctor { ...@@ -60,7 +60,7 @@ class AssignFunctor {
framework::LoDTensor *out) const { framework::LoDTensor *out) const {
if (lod_tensor.numel() == 0) return; if (lod_tensor.numel() == 0) return;
auto &out_tensor = *out; auto &out_tensor = *out;
TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor);
out_tensor.set_lod(lod_tensor.lod()); out_tensor.set_lod(lod_tensor.lod());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册