未验证 提交 01a96463 编写于 作者: Z Zhang Ting 提交者: GitHub

optimize assign op to avoid copy data from GPU to GPU (#21181)

* optimize assign op to avoid copy data from GPU to GPU, test=develop

* modified GetkernelTypeForVar and just avoid device transform, test=develop
上级 c91cb6c5
......@@ -41,6 +41,14 @@ class AssignOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
......
......@@ -47,7 +47,7 @@ class AssignFunctor {
out_rows.set_height(rows.height());
auto &t = rows.value();
auto *m = out_rows.mutable_value();
framework::TensorCopy(t, t.place(), dev_ctx_, m);
framework::TensorCopy(t, dev_ctx_.GetPlace(), dev_ctx_, m);
}
template <typename T>
......@@ -60,7 +60,7 @@ class AssignFunctor {
framework::LoDTensor *out) const {
if (lod_tensor.numel() == 0) return;
auto &out_tensor = *out;
TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor);
out_tensor.set_lod(lod_tensor.lod());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册