未验证 提交 6af2729e 编写于 作者: C crystal 提交者: GitHub

【phi】migrate gather_tree,reduce_prod to phi (#39844)

* move to phi

* migrate gather_tree_op into phi

* move reduce_prod to phi

* optimize code
上级 1db188f3
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather_tree_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -73,5 +73,3 @@ selected ids.
// Short alias for the operator namespace used by the registration macros.
namespace ops = paddle::operators;
// Registers the gather_tree operator with its op class and op maker.
REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker);
// Registers the CPU kernel of gather_tree for int32_t and int64_t elements.
REGISTER_OP_CPU_KERNEL(gather_tree, ops::GatherTreeOpKernel<int32_t>,
ops::GatherTreeOpKernel<int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_tree_op.h"
namespace paddle {
namespace operators {
// Back-traces beam-search results on the device.
//
// All three buffers are indexed as [max_length, batch_size, beam_size]
// (flattened as step * batch_size * beam_size + batch * beam_size + beam).
// For every (batch, beam) pair the kernel copies the id of the last step and
// then walks towards step 0, at each step following the parent beam recorded
// in `parents_data`.
//
// ids_data     - selected ids for every step.
// parents_data - parent beam index for every step; every value that is used
//                as an index must lie in [0, beam_size).
// out_data     - receives the re-ordered ids.
//
// CUDA_KERNEL_LOOP is defined elsewhere; it is expected to iterate `i` over
// [0, batch_size * beam_size) for this thread.
template <typename T>
__global__ void GatherTree(const T *ids_data, const T *parents_data,
                           T *out_data, const int64_t max_length,
                           const int64_t batch_size, const int64_t beam_size) {
  CUDA_KERNEL_LOOP(i, batch_size * beam_size) {
    int batch = i / beam_size;
    int beam = i % beam_size;
    // Start from the last time step of this (batch, beam) pair.
    auto idx =
        (max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
    out_data[idx] = ids_data[idx];
    auto parent = parents_data[idx];
    for (int step = max_length - 2; step >= 0; step--) {
      // `parent` comes straight from user data and is used as an offset
      // below; reject values outside the beam dimension instead of reading
      // out of bounds.
      // NOTE(review): relies on the device-side form of PADDLE_ENFORCE
      // (printf-style message) from Paddle's enforce header - confirm it is
      // reachable from this translation unit.
      PADDLE_ENFORCE(
          (parent >= 0 && parent < beam_size),
          "The parents must be in range [0, beam_size), but received "
          "parent %lld while beam size is %lld.",
          static_cast<long long>(parent),
          static_cast<long long>(beam_size));
      // Offset of beam 0 for this batch at the current step.
      idx = step * batch_size * beam_size + batch * beam_size;
      out_data[idx + beam] = ids_data[idx + parent];
      parent = parents_data[idx + parent];
    }
  }
}
// CUDA kernel class of the gather_tree operator.
//
// Reads the inputs "Ids" and "Parents", allocates the output "Out" on the
// execution place and launches the GatherTree device kernel over all
// (batch, beam) pairs. The first three dimensions of "Ids" are read as
// max_length, batch_size and beam_size.
template <typename T>
class GatherTreeOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *ids = ctx.Input<Tensor>("Ids");
    auto *parents = ctx.Input<Tensor>("Parents");
    auto *out = ctx.Output<Tensor>("Out");

    const auto *ids_data = ids->data<T>();
    const auto *parents_data = parents->data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());

    PADDLE_ENFORCE_NOT_NULL(
        ids_data, platform::errors::InvalidArgument(
                      "Input(Ids) of gather_tree should not be null."));

    PADDLE_ENFORCE_NOT_NULL(
        parents_data, platform::errors::InvalidArgument(
                          "Input(Parents) of gather_tree should not be null."));

    auto &ids_dims = ids->dims();
    int64_t max_length = ids_dims[0];
    int64_t batch_size = ids_dims[1];
    int64_t beam_size = ids_dims[2];

    auto &dev_ctx = ctx.cuda_device_context();

    // One thread per (batch, beam) pair, capped by what the device can run
    // at once; the kernel loop covers any remainder.
    const int block = 512;
    int max_threads =
        std::min(static_cast<int64_t>(dev_ctx.GetMaxPhysicalThreadCount()),
                 batch_size * beam_size);
    const int grid = std::max(max_threads / block, 1);

    // Launch on the stream owned by the device context so the kernel is
    // ordered with the other work queued on this context rather than being
    // issued on the default stream.
    GatherTree<<<grid, block, 0, dev_ctx.stream()>>>(
        ids_data, parents_data, out_data, max_length, batch_size, beam_size);
  }
};
} // namespace operators
} // namespace paddle
// Short alias for the operator namespace used by the registration macro.
namespace ops = paddle::operators;
// Registers the CUDA kernel of gather_tree for int32_t and int64_t elements.
REGISTER_OP_CUDA_KERNEL(gather_tree, ops::GatherTreeOpCUDAKernel<int32_t>,
ops::GatherTreeOpCUDAKernel<int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// CPU kernel class of the gather_tree operator.
//
// Reads the inputs "Ids" and "Parents" (both indexed as
// [max_length, batch_size, beam_size]) and writes "Out" with the same
// indexing. For every (batch, beam) pair the id of the last step is copied
// and the sequence is then traced back to step 0 by following the parent
// beam stored in "Parents".
//
// An InvalidArgument error is reported when an input buffer is null or when a
// parent index that is about to be used lies outside [0, beam_size).
template <typename T>
class GatherTreeOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *ids = ctx.Input<Tensor>("Ids");
    auto *parents = ctx.Input<Tensor>("Parents");
    auto *out = ctx.Output<Tensor>("Out");

    const auto *ids_data = ids->data<T>();
    const auto *parents_data = parents->data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());

    auto &ids_dims = ids->dims();
    auto max_length = ids_dims[0];
    auto batch_size = ids_dims[1];
    auto beam_size = ids_dims[2];

    PADDLE_ENFORCE_NOT_NULL(
        ids_data, platform::errors::InvalidArgument(
                      "Input(Ids) of gather_tree should not be null."));

    PADDLE_ENFORCE_NOT_NULL(
        parents_data, platform::errors::InvalidArgument(
                          "Input(Parents) of gather_tree should not be null."));

    for (int batch = 0; batch < batch_size; batch++) {
      for (int beam = 0; beam < beam_size; beam++) {
        // Start from the last time step of this (batch, beam) pair.
        auto idx = (max_length - 1) * batch_size * beam_size +
                   batch * beam_size + beam;
        out_data[idx] = ids_data[idx];
        auto parent = parents_data[idx];
        for (int step = max_length - 2; step >= 0; step--) {
          // `parent` is user data used as an offset below; validate it
          // before every use to avoid reading outside the input buffers.
          PADDLE_ENFORCE_GE(
              static_cast<int64_t>(parent), static_cast<int64_t>(0),
              platform::errors::InvalidArgument(
                  "The parents must be greater than or equal to 0, but "
                  "received parent %d.",
                  parent));
          PADDLE_ENFORCE_LT(
              static_cast<int64_t>(parent), static_cast<int64_t>(beam_size),
              platform::errors::InvalidArgument(
                  "The parents must be less than beam size, but received "
                  "parent %d is greater than or equal to beam size %d.",
                  parent, beam_size));
          // Offset of beam 0 for this batch at the current step.
          idx = step * batch_size * beam_size + batch * beam_size;
          out_data[idx + beam] = ids_data[idx + parent];
          parent = parents_data[idx + parent];
        }
      }
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -27,15 +27,7 @@ class CPUDeviceContext;
} // namespace paddle
// Registers the reduce_prod operator through the shared reduce-op macro.
REGISTER_REDUCE_OP(reduce_prod);
// Registers the CPU reduce_prod kernels (float, double, int, int64_t), all
// built from ReduceKernel with ProdFunctor.
REGISTER_OP_CPU_KERNEL(reduce_prod,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
double, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int, ops::ProdFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::ProdFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_prod_grad,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::ProdGradFunctor>,
......
......@@ -19,13 +19,6 @@
namespace paddle {
namespace operators {
// Functor that stores the product of `x` along `dim` into `y`, evaluating the
// assignment on the device object `place`.
struct ProdFunctor {
  template <typename DeviceContext, typename X, typename Y, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
    auto& input = *x;
    auto& output = *y;
    output.device(place) = input.prod(dim);
  }
};
struct ProdGradFunctor {
template <typename DeviceContext, typename X, typename Y, typename DX,
typename DY, typename Dim>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_tree_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void GatherTreeKernel(const Context &dev_ctx,
const DenseTensor &ids,
const DenseTensor &parents,
DenseTensor *out) {
const auto *ids_data = ids.data<T>();
const auto *parents_data = parents.data<T>();
T *out_data = dev_ctx.template Alloc<T>(out);
auto &ids_dims = ids.dims();
auto max_length = ids_dims[0];
auto batch_size = ids_dims[1];
auto beam_size = ids_dims[2];
PADDLE_ENFORCE_NOT_NULL(ids_data,
phi::errors::InvalidArgument(
"Input(Ids) of gather_tree should not be null."));
PADDLE_ENFORCE_NOT_NULL(
parents_data,
phi::errors::InvalidArgument(
"Input(Parents) of gather_tree should not be null."));
for (int batch = 0; batch < batch_size; batch++) {
for (int beam = 0; beam < beam_size; beam++) {
auto idx =
(max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
out_data[idx] = ids_data[idx];
auto parent = parents_data[idx];
for (int step = max_length - 2; step >= 0; step--) {
idx = step * batch_size * beam_size + batch * beam_size;
out_data[idx + beam] = ids_data[idx + parent];
parent = parents_data[idx + parent];
}
}
}
}
} // namespace phi
// Registers phi::GatherTreeKernel as the CPU kernel "gather_tree" (all
// layouts) for int and int64_t element types.
PD_REGISTER_KERNEL(
gather_tree, CPU, ALL_LAYOUT, phi::GatherTreeKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_prod_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
namespace phi {
template <typename T, typename Context>
void ReduceProdKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
// Registers phi::ReduceProdKernel as the CPU kernel "reduce_prod" (all
// layouts) for float, double, int and int64_t element types.
PD_REGISTER_KERNEL(reduce_prod,
CPU,
ALL_LAYOUT,
phi::ReduceProdKernel,
float,
double,
int,
int64_t) {}
......@@ -33,5 +33,13 @@ struct MeanFunctor {
}
};
//////// Prod Functor ///////
// Writes the product of `*x` taken along `dim` into `*y`; the assignment is
// evaluated on the device object passed as `place`.
struct ProdFunctor {
  template <typename DeviceContext, typename X, typename Y, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
    auto& result = *y;
    result.device(place) = (*x).prod(dim);
  }
};
} // namespace funcs
} // namespace phi
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
#pragma once
// Registers the CUDA reduce_prod kernels (float, int, double, int64_t), all
// built from ReduceCudaKernel with kps::MulFunctor and kps::IdentityFunctor.
REGISTER_OP_CUDA_KERNEL(
reduce_prod,
ops::ReduceCudaKernel<float, kps::MulFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MulFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MulFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MulFunctor, kps::IdentityFunctor>);
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Declaration of the gather_tree kernel; definitions are provided per backend.
//
// T       - element type of the tensors.
// Context - device context type of the backend.
// dev_ctx - device context used by the kernel.
// ids     - input tensor of selected ids.
// parents - input tensor of parent indices.
// out     - output tensor written by the kernel.
template <typename T, typename Context>
void GatherTreeKernel(const Context &dev_ctx,
const DenseTensor &ids,
const DenseTensor &parents,
DenseTensor *out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gather_tree_kernel.h"
namespace phi {
// Back-traces beam-search results on the device.
//
// All three buffers are indexed as [max_length, batch_size, beam_size]
// (flattened as step * batch_size * beam_size + batch * beam_size + beam).
// For every (batch, beam) pair the kernel copies the id of the last step and
// then walks towards step 0, at each step following the parent beam recorded
// in `parents_data`.
//
// ids_data     - selected ids for every step.
// parents_data - parent beam index for every step; every value that is used
//                as an index must lie in [0, beam_size).
// out_data     - receives the re-ordered ids.
//
// CUDA_KERNEL_LOOP is defined elsewhere; it is expected to iterate `i` over
// [0, batch_size * beam_size) for this thread.
template <typename T>
__global__ void GatherTree(const T *ids_data,
                           const T *parents_data,
                           T *out_data,
                           const int64_t max_length,
                           const int64_t batch_size,
                           const int64_t beam_size) {
  CUDA_KERNEL_LOOP(i, batch_size * beam_size) {
    int batch = i / beam_size;
    int beam = i % beam_size;
    // Start from the last time step of this (batch, beam) pair.
    auto idx =
        (max_length - 1) * batch_size * beam_size + batch * beam_size + beam;
    out_data[idx] = ids_data[idx];
    auto parent = parents_data[idx];
    for (int step = max_length - 2; step >= 0; step--) {
      // `parent` comes straight from user data and is used as an offset
      // below; reject values outside the beam dimension instead of reading
      // out of bounds.
      // NOTE(review): relies on the device-side form of PADDLE_ENFORCE
      // (printf-style message) from Paddle's enforce header - confirm it is
      // reachable from this translation unit.
      PADDLE_ENFORCE(
          (parent >= 0 && parent < beam_size),
          "The parents must be in range [0, beam_size), but received "
          "parent %lld while beam size is %lld.",
          static_cast<long long>(parent),
          static_cast<long long>(beam_size));
      // Offset of beam 0 for this batch at the current step.
      idx = step * batch_size * beam_size + batch * beam_size;
      out_data[idx + beam] = ids_data[idx + parent];
      parent = parents_data[idx + parent];
    }
  }
}
// GPU implementation of gather_tree.
//
// Allocates `out` through `dev_ctx` and launches the GatherTree device kernel
// over all (batch, beam) pairs. The first three dimensions of `ids` are read
// as max_length, batch_size and beam_size.
//
// dev_ctx - GPU context; supplies allocation, the thread-count limit and the
//           stream the kernel is launched on.
// ids     - selected ids for every step.
// parents - parent beam index for every step.
// out     - receives the re-ordered ids.
//
// An InvalidArgument error is reported when an input buffer is null.
template <typename T, typename Context>
void GatherTreeKernel(const Context &dev_ctx,
                      const DenseTensor &ids,
                      const DenseTensor &parents,
                      DenseTensor *out) {
  const auto *ids_data = ids.data<T>();
  const auto *parents_data = parents.data<T>();
  T *out_data = dev_ctx.template Alloc<T>(out);

  PADDLE_ENFORCE_NOT_NULL(ids_data,
                          phi::errors::InvalidArgument(
                              "Input(Ids) of gather_tree should not be null."));

  PADDLE_ENFORCE_NOT_NULL(
      parents_data,
      phi::errors::InvalidArgument(
          "Input(Parents) of gather_tree should not be null."));

  auto &ids_dims = ids.dims();
  int64_t max_length = ids_dims[0];
  int64_t batch_size = ids_dims[1];
  int64_t beam_size = ids_dims[2];

  // One thread per (batch, beam) pair, capped by what the device can run at
  // once; the kernel loop covers any remainder.
  const int block = 512;
  int max_threads =
      std::min(static_cast<int64_t>(dev_ctx.GetMaxPhysicalThreadCount()),
               batch_size * beam_size);
  const int grid = std::max(max_threads / block, 1);

  // Launch on the stream owned by the device context so the kernel is ordered
  // with the other work queued on this context rather than being issued on
  // the default stream.
  GatherTree<<<grid, block, 0, dev_ctx.stream()>>>(
      ids_data, parents_data, out_data, max_length, batch_size, beam_size);
}
} // namespace phi
// Registers phi::GatherTreeKernel as the GPU kernel "gather_tree" (all
// layouts) for int and int64_t element types.
PD_REGISTER_KERNEL(
gather_tree, GPU, ALL_LAYOUT, phi::GatherTreeKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/reduce_prod_kernel.h"
namespace phi {
// GPU reduce_prod: forwards to phi::Reduce with kps::MulFunctor as the
// reduction and kps::IdentityFunctor as the transform. The requested output
// dtype is the dtype of the input tensor `x`.
template <typename T, typename Context>
void ReduceProdKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const std::vector<int64_t>& dims,
                      bool keep_dim,
                      bool reduce_all,
                      DenseTensor* out) {
  auto result_dtype = x.dtype();
  phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(dev_ctx,
                                                        x,
                                                        reduce_all,
                                                        dims,
                                                        keep_dim,
                                                        result_dtype,
                                                        out);
}
} // namespace phi
// Registers phi::ReduceProdKernel as the GPU kernel "reduce_prod" (all
// layouts) for float, double, int and int64_t element types.
PD_REGISTER_KERNEL(reduce_prod,
GPU,
ALL_LAYOUT,
phi::ReduceProdKernel,
float,
double,
int,
int64_t) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Declaration of the reduce_prod kernel; definitions are provided per backend.
//
// T          - element type of the tensors.
// Context    - device context type of the backend.
// dev_ctx    - device context used by the kernel.
// x          - input tensor to reduce.
// dims       - dimensions to reduce over.
// keep_dim   - forwarded to the reduce implementation; presumably keeps the
//              reduced dimensions in the output shape (confirm).
// reduce_all - forwarded to the reduce implementation; presumably reduces over
//              every dimension regardless of `dims` (confirm).
// out        - output tensor written by the kernel.
template <typename T, typename Context>
void ReduceProdKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
} // namespace phi
......@@ -51,6 +51,11 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("unregistered", {}, {}, {});
}
// Maps the reduce_prod operator onto the "reduce_prod" kernel signature:
// input "X", attributes "dim" / "keep_dim" / "reduce_all", output "Out".
// The mapping is fixed; `ctx` is not consulted.
KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) {
  KernelSignature signature(
      "reduce_prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
  return signature;
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
......@@ -58,3 +63,4 @@ PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);
// Associates each reduce operator name with its argument-mapping function.
PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册