未验证 提交 a972c33f 编写于 作者: W wangchaochaohu 提交者: GitHub

Refine gather op performance in dynamic graph (dygraph) mode (#28587)

上级 ece1e4cd
......@@ -93,6 +93,15 @@ class GatherGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
// Selects the kernel type to use for a single input variable.
//
// var_name:             name of the input slot being inspected.
// tensor:               the tensor currently bound to that slot.
// expected_kernel_type: the kernel type already chosen for this operator.
//
// Every input except "Axis" gets a kernel type built from the expected data
// type combined with the tensor's own place and layout. The "Axis" input is
// answered with the expected kernel type unchanged.
// NOTE(review): presumably this keeps the framework from transforming or
// copying the Axis tensor to the kernel's device -- confirm against the
// data-transform logic in OperatorWithKernel.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name, const framework::Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  if (var_name != "Axis") {
    // General case: keep the expected data type, but describe the tensor
    // where it currently lives and in its current layout.
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
  // "Axis": report the expected kernel type as-is.
  return expected_kernel_type;
}
};
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -785,9 +785,12 @@ def gather(x, index, axis=None, name=None):
if axis is None:
axis = 0
axis_tensor = axis
if not isinstance(axis, Variable) and axis == 0:
return paddle.fluid.layers.gather(input=x, index=index, overwrite=True)
if not isinstance(axis, Variable):
with device_guard("cpu"):
axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
axis_tensor = fill_constant(
shape=[1], dtype='int64', value=axis, force_cpu=True)
if in_dygraph_mode():
return core.ops.gather(x, index, axis_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册