diff --git a/paddle/fluid/operators/range_op.cu b/paddle/fluid/operators/range_op.cu index f2c78e0f70b321814b890d3a0b6e6dffb7cc689c..6250d68730e138f30cf14d664e51bbe7a506dbc2 100644 --- a/paddle/fluid/operators/range_op.cu +++ b/paddle/fluid/operators/range_op.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/range_op.h" +#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { @@ -34,26 +35,9 @@ class CUDARangeKernel : public framework::OpKernel { auto* step_t = context.Input("Step"); auto* out = context.Output("Out"); - T start, end, step; - framework::Tensor n; - if (::paddle::platform::is_cpu_place(start_t->place())) { - start = start_t->data()[0]; - } else { - framework::TensorCopy(*start_t, platform::CPUPlace(), &n); - start = n.data()[0]; - } - if (::paddle::platform::is_cpu_place(end_t->place())) { - end = end_t->data()[0]; - } else { - framework::TensorCopy(*end_t, platform::CPUPlace(), &n); - end = n.data()[0]; - } - if (::paddle::platform::is_cpu_place(step_t->place())) { - step = step_t->data()[0]; - } else { - framework::TensorCopy(*step_t, platform::CPUPlace(), &n); - step = n.data()[0]; - } + T start = GetValue(start_t); + T end = GetValue(end_t); + T step = GetValue(step_t); int64_t size = 0; GetSize(start, end, step, &size); diff --git a/paddle/fluid/operators/utils.h b/paddle/fluid/operators/utils.h index 985c35127617bf1c4c708c3ab741ff8ca058af8a..912d538d5e9513bc0f87b5b4593468bf4f138fad 100644 --- a/paddle/fluid/operators/utils.h +++ b/paddle/fluid/operators/utils.h @@ -108,5 +108,18 @@ inline framework::DDim GetShape(const framework::ExecutionContext& ctx) { return framework::make_ddim(vec_shape); } +template +inline T GetValue(const framework::Tensor* x) { + T value = static_cast(0); + if (!platform::is_cpu_place(x->place())) { + framework::Tensor cpu_x; + framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x); + value = cpu_x.data()[0]; + } else { + value = x->data()[0]; + } + return value; +} + } // namespace operators } // namespace paddle