From c00f82784390f7066d900b67122b47bb0b77ad92 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 30 Jun 2020 20:32:45 +0800 Subject: [PATCH] Avoid data transforming ShapeTensor from CPU to GPU in fill_constant op. (#25267) --- paddle/fluid/operators/fill_constant_op.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 36873f16808..35d54577bfe 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -51,6 +51,17 @@ class FillConstantOp : public framework::OperatorWithKernel { } protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { + return expected_kernel_type; + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( -- GitLab