Unverified commit c252b1de authored by Chen Weihang, committed by GitHub

Simplify size op impl (#45808)

* simplify size op

* transfer to CUDA manually

* fix copy error
Parent 7d000112
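For orientation, the `size` op computes the total element count of a tensor as an int64 value. A minimal Python sketch of the user-visible behavior is below; treating `paddle.numel` as the Python front end of this kernel is an assumption for illustration, not something stated in the patch.

```python
import paddle

# Minimal sketch (assumes paddle.numel is the Python entry point of the size op).
x = paddle.rand([2, 3, 4])      # any dtype the kernel is registered for
n = paddle.numel(x)             # int64 tensor holding the element count
assert int(n) == 24
```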
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/size_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/size_kernel_impl.h"
PD_REGISTER_KERNEL(size,
                   CPU,
                   ALL_LAYOUT,
                   phi::SizeKernel,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   float,
                   double,
                   bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/size_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/size_kernel_impl.h"
PD_REGISTER_KERNEL(size,
                   GPU,
                   ALL_LAYOUT,
                   phi::SizeKernel,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   float,
                   double,
                   bool) {}
@@ -12,28 +12,33 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#pragma once
+#include "paddle/phi/kernels/size_kernel.h"
 
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 
 namespace phi {
 
-template <typename T, typename Context>
+template <typename Context>
 void SizeKernel(const Context& ctx,
                 const DenseTensor& input,
                 DenseTensor* out) {
-  auto place = ctx.GetPlace();
-  auto out_data = ctx.template Alloc<int64_t>(out);
-  auto cpu_place = phi::CPUPlace();
-  if (place == cpu_place) {
-    out_data[0] = input.numel();
-  } else {
-    DenseTensor cpu_tensor;
-    cpu_tensor.Resize(out->dims());
-    auto cpu_data = ctx.template HostAlloc<int64_t>(&cpu_tensor);
-    cpu_data[0] = input.numel();
-    phi::Copy(ctx, cpu_tensor, place, false, out);
-  }
+  auto* out_data = ctx.template HostAlloc<int64_t>(out);
+  out_data[0] = input.numel();
 }
 
 }  // namespace phi
+
+PD_REGISTER_GENERAL_KERNEL(
+    size, CPU, ALL_LAYOUT, phi::SizeKernel<phi::CPUContext>, ALL_DTYPE) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
+}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_GENERAL_KERNEL(
+    size, GPU, ALL_LAYOUT, phi::SizeKernel<phi::GPUContext>, ALL_DTYPE) {
+  kernel->OutputAt(0)
+      .SetBackend(phi::Backend::CPU)
+      .SetDataType(phi::DataType::INT64);
+}
+#endif
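A consequence of the registration above: the GPU registration pins output 0 to the CPU backend, so the element count is written via `HostAlloc` and lives in host memory even when the input is a GPU tensor, which is why the old place check and `phi::Copy` branch could be dropped. A hedged Python-side check of that placement (again assuming `paddle.numel` maps to this op, and that the eager API exposes the kernel's output place directly):

```python
import paddle

# Hedged sketch: with SetBackend(phi::Backend::CPU) in the GPU registration,
# the size op's output is expected to be a host (CPU) tensor even for GPU inputs.
if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
x = paddle.rand([4, 5])
n = paddle.numel(x)
print(n.place)        # expected to report a CPU place under this change
assert int(n) == 20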
@@ -18,7 +18,7 @@
 
 namespace phi {
 
-template <typename T, typename Context>
+template <typename Context>
 void SizeKernel(const Context& ctx, const DenseTensor& input, DenseTensor* out);
 
 }  // namespace phi
@@ -1140,6 +1140,9 @@ def all_gather_object(object_list, obj, group=None):
     ), "all_gather_object doesn't support static graph mode."
 
     tensor, len_of_tensor = _convert_object_to_tensor(obj)
+    if paddle.get_device() != "cpu":
+        len_of_tensor = len_of_tensor._copy_to(
+            paddle.framework._current_expected_place(), False)
 
     # gather len_of_tensor from all ranks
     list_len_of_tensor = []
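This Python hunk is presumably the "transfer to CUDA manually" / "fix copy error" part of the commit message: once the size kernel's output is pinned to CPU, `len_of_tensor` arrives on the host, while the following all_gather expects a tensor on the current device, so it is copied back with the internal `_copy_to` helper. A standalone sketch of the same pattern outside the collective code (the second `_copy_to` argument is reproduced verbatim from the patch):

```python
import paddle

# Sketch of the copy-back pattern from the hunk above (not the library code itself).
length = paddle.numel(paddle.rand([8]))      # host tensor after this change
if paddle.get_device() != "cpu":
    length = length._copy_to(
        paddle.framework._current_expected_place(), False)
print(length.place)                          # now on the current device
```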