From a972c33fd7b93a24cc199ad4f3ae01ea371d3972 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Mon, 16 Nov 2020 19:33:33 +0800 Subject: [PATCH] refine gather OP performance for dynamic mode (#28587) --- paddle/fluid/operators/gather_op.cc | 9 +++++++++ python/paddle/tensor/manipulation.py | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 648afe7e821..162766546b3 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -93,6 +93,15 @@ class GatherGradOp : public framework::OperatorWithKernel { ctx, framework::GradVarName("Out")), ctx.device_context()); } + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "Axis") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class GatherOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4a01f7e7fa3..adad9cfdc26 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -785,9 +785,12 @@ def gather(x, index, axis=None, name=None): if axis is None: axis = 0 axis_tensor = axis + if not isinstance(axis, Variable) and axis == 0: + return paddle.fluid.layers.gather(input=x, index=index, overwrite=True) if not isinstance(axis, Variable): with device_guard("cpu"): - axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis) + axis_tensor = fill_constant( + shape=[1], dtype='int64', value=axis, force_cpu=True) if in_dygraph_mode(): return core.ops.gather(x, index, axis_tensor) -- GitLab