From 714b0076b6d4af692a4aa483b15208f984061f8e Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Tue, 17 Mar 2020 10:05:50 +0800 Subject: [PATCH] Override GetKernelTypeForVar to avoid device transform, test=develop (#23032) --- paddle/fluid/operators/shape_op.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index 4a0f41ae54d..edc538c5056 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/shape_op.h" +#include #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -30,6 +31,15 @@ class ShapeOp : public framework::OperatorWithKernel { auto in_dim = ctx->GetInputDim("Input"); ctx->SetOutputDim("Out", {in_dim.size()}); } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } }; class ShapeOpMaker : public framework::OpProtoAndCheckerMaker { -- GitLab