未验证 提交 95122ebe 编写于 作者: Y Yiqun Liu 提交者: GitHub

Avoid CPU -> CPU memory copy when start, end, and step are already on CPU. (#29088)

上级 d815fbf9
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/range_op.h" #include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
...@@ -34,26 +35,9 @@ class CUDARangeKernel : public framework::OpKernel<T> { ...@@ -34,26 +35,9 @@ class CUDARangeKernel : public framework::OpKernel<T> {
auto* step_t = context.Input<framework::Tensor>("Step"); auto* step_t = context.Input<framework::Tensor>("Step");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
T start, end, step; T start = GetValue<T>(start_t);
framework::Tensor n; T end = GetValue<T>(end_t);
if (::paddle::platform::is_cpu_place(start_t->place())) { T step = GetValue<T>(step_t);
start = start_t->data<T>()[0];
} else {
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
start = n.data<T>()[0];
}
if (::paddle::platform::is_cpu_place(end_t->place())) {
end = end_t->data<T>()[0];
} else {
framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
end = n.data<T>()[0];
}
if (::paddle::platform::is_cpu_place(step_t->place())) {
step = step_t->data<T>()[0];
} else {
framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
step = n.data<T>()[0];
}
int64_t size = 0; int64_t size = 0;
GetSize(start, end, step, &size); GetSize(start, end, step, &size);
......
...@@ -108,5 +108,18 @@ inline framework::DDim GetShape(const framework::ExecutionContext& ctx) { ...@@ -108,5 +108,18 @@ inline framework::DDim GetShape(const framework::ExecutionContext& ctx) {
return framework::make_ddim(vec_shape); return framework::make_ddim(vec_shape);
} }
// Returns element 0 of tensor `x` as a scalar of type T.
//
// If `x` already lives on a CPU place the element is read in place, so no
// copy is made. Otherwise the tensor is first copied into a temporary tensor
// on platform::CPUPlace() via framework::TensorCopy and the element is read
// from that temporary.
//
// NOTE(review): `x` is dereferenced unconditionally and index 0 is read
// without a size check -- presumably callers always pass a non-null tensor
// holding at least one element; confirm against call sites.
template <typename T>
inline T GetValue(const framework::Tensor* x) {
  // Fast path: data is already host-accessible.
  if (platform::is_cpu_place(x->place())) {
    return x->data<T>()[0];
  }

  // Slow path: stage the tensor on the CPU before reading the scalar.
  framework::Tensor host_copy;
  framework::TensorCopy(*x, platform::CPUPlace(), &host_copy);
  return host_copy.data<T>()[0];
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册