未验证 提交 a972c33f 编写于 作者: W wangchaochaohu 提交者: GitHub

refine gather OP performance for dynamic mode (#28587)

上级 ece1e4cd
...@@ -93,6 +93,15 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -93,6 +93,15 @@ class GatherGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "Axis") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class GatherOpMaker : public framework::OpProtoAndCheckerMaker { class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -785,9 +785,12 @@ def gather(x, index, axis=None, name=None): ...@@ -785,9 +785,12 @@ def gather(x, index, axis=None, name=None):
if axis is None: if axis is None:
axis = 0 axis = 0
axis_tensor = axis axis_tensor = axis
if not isinstance(axis, Variable) and axis == 0:
return paddle.fluid.layers.gather(input=x, index=index, overwrite=True)
if not isinstance(axis, Variable): if not isinstance(axis, Variable):
with device_guard("cpu"): with device_guard("cpu"):
axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis) axis_tensor = fill_constant(
shape=[1], dtype='int64', value=axis, force_cpu=True)
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.gather(x, index, axis_tensor) return core.ops.gather(x, index, axis_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册