Unverified · Commit 9ed6c895 · Authored by jiangcheng, committed by GitHub

optimize range op by place parameters on cpu rather than gpu, test=develop (#30811)

Parent 3789a699
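At its core the change is about where the range op's scalar inputs live. `Start`, `End`, and `Step` have to be read on the host before the kernel can be launched, because both the output size and the launch configuration depend on them; when those tensors sit in GPU memory, every launch pays a blocking device-to-host copy per scalar, whereas CPU-resident tensors can be read with a plain pointer dereference. A minimal sketch of that difference using the raw CUDA runtime API rather than Paddle's `TensorCopy` (the helper name `ReadScalar` is illustrative, not part of the codebase):

```cpp
#include <cuda_runtime.h>

// Illustrative helper: the cost of fetching one scalar parameter depends
// on where it is stored. The GPU branch mirrors what a TensorCopy to
// CPUPlace ultimately performs: a synchronizing device-to-host transfer.
float ReadScalar(const float* data, bool on_gpu) {
  if (!on_gpu) {
    return data[0];  // host memory: a direct read, nothing to wait for
  }
  float value = 0.f;
  cudaMemcpy(&value, data, sizeof(float), cudaMemcpyDeviceToHost);
  return value;
}
```

The diff below applies this idea twice: the CUDA kernel gains a fast path for CPU-resident inputs, and the Python API creates the scalar inputs on the CPU in the first place.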
```diff
@@ -12,6 +12,7 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/range_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
```
```diff
@@ -33,13 +34,26 @@ class CUDARangeKernel : public framework::OpKernel<T> {
     auto* step_t = context.Input<framework::Tensor>("Step");
     auto* out = context.Output<framework::Tensor>("Out");
+    T start, end, step;
     framework::Tensor n;
+    if (::paddle::platform::is_cpu_place(start_t->place())) {
+      start = start_t->data<T>()[0];
+    } else {
       framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
-      T start = n.data<T>()[0];
+      start = n.data<T>()[0];
+    }
+    if (::paddle::platform::is_cpu_place(end_t->place())) {
+      end = end_t->data<T>()[0];
+    } else {
       framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
-      T end = n.data<T>()[0];
+      end = n.data<T>()[0];
+    }
+    if (::paddle::platform::is_cpu_place(step_t->place())) {
+      step = step_t->data<T>()[0];
+    } else {
       framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
-      T step = n.data<T>()[0];
+      step = n.data<T>()[0];
+    }
     int64_t size = 0;
     GetSize(start, end, step, &size);
```
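The hunk ends by calling `GetSize(start, end, step, &size)`, whose definition lives in the op's header rather than in this diff. As a rough sketch of the conventional formula (an assumption, not code copied from Paddle), the element count is a ceiling division of the span by the step:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <type_traits>

// Assumed range-size formula; the real GetSize also validates that step
// is non-zero and actually moves start towards end.
// Example: start=0, end=10, step=3 -> 4 elements: 0, 3, 6, 9.
template <typename T>
int64_t RangeSize(T start, T end, T step) {
  if (std::is_integral<T>::value) {
    // integer ceiling division: (|end - start| + |step| - 1) / |step|
    return (std::abs(end - start) + std::abs(step) - 1) / std::abs(step);
  }
  return static_cast<int64_t>(std::ceil(std::abs((end - start) / step)));
}
```

Whatever the exact formula, the point of the hunk above is that `start`, `end`, and `step` are already host values by the time it runs, with no device synchronization when the inputs were placed on the CPU.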
```diff
@@ -47,7 +61,7 @@ class CUDARangeKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
     auto stream = context.cuda_device_context().stream();
-    int block = 512;
+    int block = std::min(size, static_cast<int64_t>(256));
     int grid = (size + block - 1) / block;
     RangeKernel<T><<<grid, block, 0, stream>>>(start, step, size, out_data);
   }
```
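The hunk just above swaps the fixed 512-thread block for `std::min(size, 256)` (hence the new `<algorithm>` include), so a tiny range no longer launches a mostly idle block, and the grid is the usual rounded-up quotient. A self-contained sketch of that launch math with an illustrative kernel body (the real `RangeKernel` is defined elsewhere in the file and is not shown in this diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel: thread i writes start + i * step, guarded so the
// rounded-up grid never writes past the end of the output buffer.
template <typename T>
__global__ void FillRange(T start, T step, int64_t size, T* out) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < size) out[idx] = start + step * idx;
}

int main() {
  const int64_t size = 10;  // e.g. the result of range(0, 20, 2)
  float* d_out = nullptr;
  cudaMalloc(&d_out, size * sizeof(float));

  // Same launch-size logic as the patched kernel: cap the block at 256
  // threads (fewer for tiny ranges) and round the grid count up.
  int block = static_cast<int>(std::min<int64_t>(size, 256));
  int grid = static_cast<int>((size + block - 1) / block);
  FillRange<float><<<grid, block>>>(0.f, 2.f, size, d_out);

  float host[10];
  cudaMemcpy(host, d_out, sizeof(host), cudaMemcpyDeviceToHost);
  for (int i = 0; i < size; ++i) printf("%g ", host[i]);  // 0 2 4 ... 18
  printf("\n");
  cudaFree(d_out);
  return 0;
}
```

The remaining hunk is the Python side of the same change: the scalar `start`, `end`, and `step` arguments are materialized with `force_cpu=True`, so they are created in host memory and the kernel's no-copy path above is taken.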
```diff
@@ -1374,19 +1374,19 @@ def range(start, end, step, dtype, name=None):
     if not isinstance(start, Variable):
         with device_guard("cpu"):
-            start = fill_constant([1], dtype, start)
+            start = fill_constant([1], dtype, start, force_cpu=True)
     elif start.dtype != dtype:
         start = cast(start, dtype)

     if not isinstance(end, Variable):
         with device_guard("cpu"):
-            end = fill_constant([1], dtype, end)
+            end = fill_constant([1], dtype, end, force_cpu=True)
     elif end.dtype != dtype:
         end = cast(end, dtype)

     if not isinstance(step, Variable):
         with device_guard("cpu"):
-            step = fill_constant([1], dtype, step)
+            step = fill_constant([1], dtype, step, force_cpu=True)
     elif step.dtype != dtype:
         step = cast(step, dtype)
```