未验证 提交 2d24f56a 编写于 作者: Z Zhang Ting 提交者: GitHub

avoid data transfer, test=develop (#25810)

上级 f40a50d1
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/range_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -65,6 +66,13 @@ class RangeOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim("Out", {-1});
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return expected_kernel_type;
}
};
class RangeOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -18,7 +18,7 @@ from six.moves import reduce
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..initializer import Initializer
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard
from ..framework import Variable
from ..initializer import Constant
from ..core import VarDesc
......@@ -1394,17 +1394,20 @@ def range(start, end, step, dtype, name=None):
dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable):
start = fill_constant([1], dtype, start)
with device_guard("cpu"):
start = fill_constant([1], dtype, start)
elif start.dtype != dtype:
start = cast(start, dtype)
if not isinstance(end, Variable):
end = fill_constant([1], dtype, end)
with device_guard("cpu"):
end = fill_constant([1], dtype, end)
elif end.dtype != dtype:
end = cast(end, dtype)
if not isinstance(step, Variable):
step = fill_constant([1], dtype, step)
with device_guard("cpu"):
step = fill_constant([1], dtype, step)
elif step.dtype != dtype:
step = cast(step, dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册