未验证 提交 2d24f56a 编写于 作者: Z Zhang Ting 提交者: GitHub

avoid data transfer, test=develop (#25810)

上级 f40a50d1
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/range_op.h" #include "paddle/fluid/operators/range_op.h"
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -65,6 +66,13 @@ class RangeOp : public framework::OperatorWithKernel { ...@@ -65,6 +66,13 @@ class RangeOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", {-1}); ctx->SetOutputDim("Out", {-1});
} }
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return expected_kernel_type;
}
}; };
class RangeOpMaker : public framework::OpProtoAndCheckerMaker { class RangeOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -18,7 +18,7 @@ from six.moves import reduce ...@@ -18,7 +18,7 @@ from six.moves import reduce
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import Initializer from ..initializer import Initializer
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard
from ..framework import Variable from ..framework import Variable
from ..initializer import Constant from ..initializer import Constant
from ..core import VarDesc from ..core import VarDesc
...@@ -1394,17 +1394,20 @@ def range(start, end, step, dtype, name=None): ...@@ -1394,17 +1394,20 @@ def range(start, end, step, dtype, name=None):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable): if not isinstance(start, Variable):
start = fill_constant([1], dtype, start) with device_guard("cpu"):
start = fill_constant([1], dtype, start)
elif start.dtype != dtype: elif start.dtype != dtype:
start = cast(start, dtype) start = cast(start, dtype)
if not isinstance(end, Variable): if not isinstance(end, Variable):
end = fill_constant([1], dtype, end) with device_guard("cpu"):
end = fill_constant([1], dtype, end)
elif end.dtype != dtype: elif end.dtype != dtype:
end = cast(end, dtype) end = cast(end, dtype)
if not isinstance(step, Variable): if not isinstance(step, Variable):
step = fill_constant([1], dtype, step) with device_guard("cpu"):
step = fill_constant([1], dtype, step)
elif step.dtype != dtype: elif step.dtype != dtype:
step = cast(step, dtype) step = cast(step, dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册