Unverified commit 9ed6c895, authored by jiangcheng, committed by GitHub

optimize range op by placing parameters on CPU rather than GPU, test=develop (#30811)

Parent 3789a699
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/range_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -33,13 +34,26 @@ class CUDARangeKernel : public framework::OpKernel<T> {
     auto* step_t = context.Input<framework::Tensor>("Step");
     auto* out = context.Output<framework::Tensor>("Out");
+    T start, end, step;
     framework::Tensor n;
-    framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
-    T start = n.data<T>()[0];
-    framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
-    T end = n.data<T>()[0];
-    framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
-    T step = n.data<T>()[0];
+    if (::paddle::platform::is_cpu_place(start_t->place())) {
+      start = start_t->data<T>()[0];
+    } else {
+      framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
+      start = n.data<T>()[0];
+    }
+    if (::paddle::platform::is_cpu_place(end_t->place())) {
+      end = end_t->data<T>()[0];
+    } else {
+      framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
+      end = n.data<T>()[0];
+    }
+    if (::paddle::platform::is_cpu_place(step_t->place())) {
+      step = step_t->data<T>()[0];
+    } else {
+      framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
+      step = n.data<T>()[0];
+    }
     int64_t size = 0;
     GetSize(start, end, step, &size);
@@ -47,7 +61,7 @@ class CUDARangeKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
     auto stream = context.cuda_device_context().stream();
-    int block = 512;
+    int block = std::min(size, static_cast<int64_t>(256));
     int grid = (size + block - 1) / block;
     RangeKernel<T><<<grid, block, 0, stream>>>(start, step, size, out_data);
   }
......
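Note: the three if/else branches added to range_op.cu above repeat one pattern: read a scalar from an input tensor that may live on either CPU or GPU, going through a temporary host tensor only when the value is not already on CPU. The sketch below expresses that pattern as a stand-alone helper; the name GetHostScalar and its placement are illustrative assumptions and are not part of this commit.

// Illustrative helper (not part of this commit): returns element 0 of a
// tensor, copying to host only when the tensor does not already reside on
// CPU, mirroring the branches added in CUDARangeKernel::Compute above.
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

template <typename T>
T GetHostScalar(const framework::Tensor& t) {
  if (platform::is_cpu_place(t.place())) {
    // Already on CPU: read the value directly, no copy needed.
    return t.data<T>()[0];
  }
  // On GPU: stage the single element through a temporary CPU tensor,
  // exactly as the kernel's else-branches do with TensorCopy.
  framework::Tensor cpu_copy;
  framework::TensorCopy(t, platform::CPUPlace(), &cpu_copy);
  return cpu_copy.data<T>()[0];
}

}  // namespace operators
}  // namespace paddle

With such a helper, the kernel body above would reduce to start = GetHostScalar<T>(*start_t); plus the equivalent calls for end and step.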
@@ -1374,19 +1374,19 @@ def range(start, end, step, dtype, name=None):
     if not isinstance(start, Variable):
         with device_guard("cpu"):
-            start = fill_constant([1], dtype, start)
+            start = fill_constant([1], dtype, start, force_cpu=True)
     elif start.dtype != dtype:
         start = cast(start, dtype)
     if not isinstance(end, Variable):
         with device_guard("cpu"):
-            end = fill_constant([1], dtype, end)
+            end = fill_constant([1], dtype, end, force_cpu=True)
     elif end.dtype != dtype:
         end = cast(end, dtype)
     if not isinstance(step, Variable):
         with device_guard("cpu"):
-            step = fill_constant([1], dtype, step)
+            step = fill_constant([1], dtype, step, force_cpu=True)
     elif step.dtype != dtype:
         step = cast(step, dtype)
......