未验证 提交 c00f8278 编写于 作者: Y Yiqun Liu 提交者: GitHub

Avoid data transforming ShapeTensor from CPU to GPU in fill_constant op. (#25267)

上级 5e8e6dad
...@@ -51,6 +51,17 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -51,6 +51,17 @@ class FillConstantOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
return expected_kernel_type;
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册