未验证 提交 282cba48 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Migrate tile_op into Phi (#40371)

* [Phi] Migrate tile_op into Phi

* fix tile_sig

* fix include headers

* fix using
上级 34d4b40d
......@@ -12,11 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tile_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -26,66 +30,6 @@ class TileOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Tile");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Tile");
auto x_dims = ctx->GetInputDim("X");
auto repeat_times = ctx->Attrs().Get<std::vector<int>>("repeat_times");
if (repeat_times.size() == 0) {
repeat_times = std::vector<int>(x_dims.size(), -1);
}
PADDLE_ENFORCE_LE(
x_dims.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'x' for tile op "
"must not be greater than %d, but the value received is %d.",
MAX_RANK_SUPPORTED, x_dims.size()));
PADDLE_ENFORCE_LE(
repeat_times.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must not be greater than %d, but the value received is %d.",
MAX_RANK_SUPPORTED, repeat_times.size()));
PADDLE_ENFORCE_GE(
repeat_times.size(), 1,
platform::errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must be positive integers, but the value received is %d.",
repeat_times.size()));
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), repeat_times.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = phi::vectorize<int>(x_dims);
if (x_dim_vec.size() > repeat_times.size()) {
auto diff = x_dim_vec.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, -1);
} else {
auto diff = repeat_times.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
}
for (size_t i = 0; i < repeat_times.size(); ++i) {
if (x_dim_vec[i] == -1 || repeat_times[i] == -1) {
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
repeat_times[i], 0,
platform::errors::InvalidArgument(
"Every element of the input 'repeat_times' for tile op must be "
"greater than 0, but the value given is %d.",
repeat_times[i]));
out_shape[i] = x_dim_vec[i] * repeat_times[i];
}
}
ctx->SetOutputDim("Out", phi::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -268,38 +212,15 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(TileGradNoNeedBufVarsInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(tile, TileInferMetaFunctor,
PD_INFER_META(phi::TileInferMeta));
REGISTER_OPERATOR(tile, ops::TileOp, ops::TileOpMaker,
ops::TileGradOpMaker<paddle::framework::OpDesc>,
ops::TileGradOpMaker<paddle::imperative::OpBase>);
ops::TileGradOpMaker<paddle::imperative::OpBase>,
TileInferMetaFunctor);
REGISTER_OPERATOR(tile_grad, ops::TileGradOp,
ops::TileDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::TileDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::TileGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
tile, ops::TileKernel<paddle::platform::CPUDeviceContext, float>,
ops::TileKernel<paddle::platform::CPUDeviceContext, double>,
ops::TileKernel<paddle::platform::CPUDeviceContext, int>,
ops::TileKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TileKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
tile_grad, ops::TileGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TileGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TileGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::TileGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL(
tile, ops::TileKernel<paddle::platform::CUDADeviceContext, float>,
ops::TileKernel<paddle::platform::CUDADeviceContext, double>,
ops::TileKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::TileKernel<paddle::platform::CUDADeviceContext, int>,
ops::TileKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TileKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
tile_grad, ops::TileGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::TileGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
namespace paddle {
namespace operators {
inline std::vector<int> get_repeat_times(
const framework::ExecutionContext& ctx) {
if (ctx.HasInput("RepeatTimes")) {
auto* repeat_tensor = ctx.Input<framework::LoDTensor>("RepeatTimes");
auto* repeat_data = repeat_tensor->data<int>();
framework::Tensor cpu_repeat_tensor;
if (platform::is_gpu_place(repeat_tensor->place()) ||
platform::is_xpu_place(repeat_tensor->place()) ||
platform::is_npu_place(repeat_tensor->place())) {
paddle::framework::TensorCopySync(*repeat_tensor, platform::CPUPlace(),
&cpu_repeat_tensor);
repeat_data = cpu_repeat_tensor.data<int>();
}
auto vec_repeat_times =
std::vector<int>(repeat_data, repeat_data + repeat_tensor->numel());
return vec_repeat_times;
}
auto list_repeat_times_tensor =
ctx.MultiInput<framework::Tensor>("repeat_times_tensor");
if (list_repeat_times_tensor.size() > 0) {
// get tensor from
std::vector<int> vec_repeat_times;
for (size_t i = 0; i < list_repeat_times_tensor.size(); ++i) {
auto tensor = list_repeat_times_tensor[i];
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place()) ||
platform::is_npu_place(tensor->place())) {
framework::Tensor temp;
paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_repeat_times.push_back(*temp.data<int32_t>());
} else {
vec_repeat_times.push_back(*tensor->data<int32_t>());
}
}
return vec_repeat_times;
} else {
return ctx.Attr<std::vector<int>>("repeat_times");
}
}
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::To32BitIndex;
template <typename DeviceContext, typename T>
class TileKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1, platform::errors::InvalidArgument(
"The rank of the input 'x' for tile op must be a positive "
"integer, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'x' for tile op "
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
auto repeat_times = get_repeat_times(context);
int repeat_times_size = repeat_times.size();
PADDLE_ENFORCE_GE(
repeat_times_size, 1,
platform::errors::InvalidArgument(
"The number of elements of the input 'repeat_times' for tile "
"op must be positive, but the value received is %d.",
repeat_times_size));
PADDLE_ENFORCE_LE(
repeat_times_size, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number of elements of the input 'repeat_times' for tile op "
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, repeat_times_size));
rank = std::max(rank, repeat_times_size);
switch (rank) {
case 1:
Tile<1>(context);
break;
case 2:
Tile<2>(context);
break;
case 3:
Tile<3>(context);
break;
case 4:
Tile<4>(context);
break;
case 5:
Tile<5>(context);
break;
case 6:
Tile<6>(context);
break;
}
}
protected:
template <int Rank>
void Tile(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto repeat_times = get_repeat_times(context);
for (size_t i = 0; i < repeat_times.size(); ++i) {
PADDLE_ENFORCE_GT(
repeat_times[i], 0,
platform::errors::InvalidArgument(
"All elements of the input 'repeat_times' for tile op must "
"be positive integers, but the value received is %d.",
repeat_times[i]));
}
auto vec_in_dims = phi::vectorize<int>(in_dims);
if (repeat_times.size() < vec_in_dims.size()) {
int diff = vec_in_dims.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, 1);
} else {
int diff = repeat_times.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
}
PADDLE_ENFORCE_EQ(
repeat_times.size(), vec_in_dims.size(),
platform::errors::InvalidArgument(
"The rank (%d) of the input 'x' and the rank (%d) of the input "
"'repeat_times' for tile op must match after promotion.",
vec_in_dims.size(), repeat_times.size()));
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = phi::make_ddim(vec_in_dims);
framework::DDim out_dims(new_in_dims);
for (size_t i = 0; i < repeat_times.size(); ++i) {
out_dims[i] *= repeat_times[i];
}
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
} else {
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
}
};
template <typename DeviceContext, typename T>
class TileGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto repeat_times = get_repeat_times(context);
auto x_dims = x->dims();
auto vec_in_dims = phi::vectorize<int>(x_dims);
if (repeat_times.size() < vec_in_dims.size()) {
int diff = vec_in_dims.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, 1);
} else {
int diff = repeat_times.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
}
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*dout, context.GetPlace(), context.device_context(),
dx);
// TensorCopy may change the dims of dx
dx->Resize(x_dims);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"Th rank of the input 'Out@GRAD' for tile_grad op "
" must be greater than or equal to 1, but "
"the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for tile_grad op "
"must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) {
case 1:
TileBackward<1>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 2:
TileBackward<2>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 3:
TileBackward<3>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 4:
TileBackward<4>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 5:
TileBackward<5>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 6:
TileBackward<6>(context, reshape_dims_vec, reduce_dims_vec);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}
}
protected:
template <int Dims>
void TileBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/operator.h"
#define MAX_RANK_SUPPORTED 6
namespace paddle {
namespace operators {
inline std::vector<int> get_repeat_times(
const framework::ExecutionContext& ctx) {
if (ctx.HasInput("RepeatTimes")) {
auto* repeat_tensor = ctx.Input<framework::LoDTensor>("RepeatTimes");
auto* repeat_data = repeat_tensor->data<int>();
framework::Tensor cpu_repeat_tensor;
if (platform::is_gpu_place(repeat_tensor->place()) ||
platform::is_xpu_place(repeat_tensor->place()) ||
platform::is_npu_place(repeat_tensor->place())) {
paddle::framework::TensorCopySync(*repeat_tensor, platform::CPUPlace(),
&cpu_repeat_tensor);
repeat_data = cpu_repeat_tensor.data<int>();
}
auto vec_repeat_times =
std::vector<int>(repeat_data, repeat_data + repeat_tensor->numel());
return vec_repeat_times;
}
auto list_repeat_times_tensor =
ctx.MultiInput<framework::Tensor>("repeat_times_tensor");
if (list_repeat_times_tensor.size() > 0) {
// get tensor from
std::vector<int> vec_repeat_times;
for (size_t i = 0; i < list_repeat_times_tensor.size(); ++i) {
auto tensor = list_repeat_times_tensor[i];
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place()) ||
platform::is_npu_place(tensor->place())) {
framework::Tensor temp;
paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_repeat_times.push_back(*temp.data<int32_t>());
} else {
vec_repeat_times.push_back(*tensor->data<int32_t>());
}
}
return vec_repeat_times;
} else {
return ctx.Attr<std::vector<int>>("repeat_times");
}
}
} // namespace operators
} // namespace paddle
......@@ -11,7 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/tile_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/tile_op_functor.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -11,11 +11,14 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/tile_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/tile_op_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class TileXPUKernel : public framework::OpKernel<T> {
public:
......
......@@ -395,6 +395,74 @@ void MultinomialInferMeta(const MetaTensor& x,
out->set_dtype(DataType::INT64);
}
void TileInferMeta(const MetaTensor& x,
const ScalarArray& repeat_times,
MetaTensor* out,
MetaConfig config) {
#define MAX_RANK_SUPPORTED 6
auto repeat_times_data = repeat_times.GetData();
auto x_dims = x.dims();
if (repeat_times_data.size() == 0) {
repeat_times_data = std::vector<int64_t>(x_dims.size(), -1);
}
PADDLE_ENFORCE_LE(
x_dims.size(),
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank of the input 'x' for tile op "
"must not be greater than %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
x_dims.size()));
PADDLE_ENFORCE_LE(
repeat_times_data.size(),
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must not be greater than %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
repeat_times_data.size()));
PADDLE_ENFORCE_GE(
repeat_times_data.size(),
1,
errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must be positive integers, but the value received is %d.",
repeat_times_data.size()));
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), repeat_times_data.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = phi::vectorize<int>(x_dims);
if (x_dim_vec.size() > repeat_times_data.size()) {
auto diff = x_dim_vec.size() - repeat_times_data.size();
repeat_times_data.insert(repeat_times_data.begin(), diff, -1);
} else {
auto diff = repeat_times_data.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
}
for (size_t i = 0; i < repeat_times_data.size(); ++i) {
if (x_dim_vec[i] == -1 || repeat_times_data[i] == -1) {
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
repeat_times_data[i],
0,
errors::InvalidArgument(
"Every element of the input 'repeat_times' for tile op must be "
"greater than 0, but the value given is %d.",
repeat_times_data[i]));
out_shape[i] = x_dim_vec[i] * repeat_times_data[i];
}
}
out->set_dims(phi::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
out->share_lod(x);
}
}
void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* out,
......
......@@ -100,6 +100,11 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void TileInferMeta(const MetaTensor& x,
const ScalarArray& repeat_times,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SumRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tile_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_grad_kernel_impl.h"
PD_REGISTER_KERNEL(tile_grad,
CPU,
ALL_LAYOUT,
phi::TileGradKernel,
bool,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tile_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_kernel_impl.h"
PD_REGISTER_KERNEL(
tile, CPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tile_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_grad_kernel_impl.h"
PD_REGISTER_KERNEL(tile_grad,
GPU,
ALL_LAYOUT,
phi::TileGradKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tile_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_kernel_impl.h"
PD_REGISTER_KERNEL(tile,
GPU,
ALL_LAYOUT,
phi::TileKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/tile_grad_kernel.h"
namespace phi {
template <typename Context, typename T, int Dims>
void TileBackward(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec,
DenseTensor* x_grad) {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
dev_ctx.template Alloc<T>(x_grad);
auto eigen_x_grad = EigenVector<T>::Flatten(*x_grad);
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto eigen_out_grad = EigenVector<T>::Flatten(out_grad);
auto& place = *dev_ctx.eigen_device();
funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, eigen_x_grad, eigen_out_grad, reduce_dims, reshape_dims);
}
template <typename T, typename Context>
void TileGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& repeat_times,
DenseTensor* x_grad) {
auto x_dims = x.dims();
auto vec_x_dims = phi::vectorize<int>(x_dims);
auto repeat_times_data = repeat_times.GetData();
if (repeat_times_data.size() < vec_x_dims.size()) {
int diff = vec_x_dims.size() - repeat_times_data.size();
repeat_times_data.insert(repeat_times_data.begin(), diff, 1);
} else {
int diff = repeat_times_data.size() - vec_x_dims.size();
vec_x_dims.insert(vec_x_dims.begin(), diff, 1);
}
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times_data.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times_data[i]);
reshape_dims_vec.push_back(vec_x_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times_data.size(); i++) {
if (repeat_times_data[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
dev_ctx.template Alloc<T>(x_grad);
paddle::framework::TensorCopy(
out_grad, dev_ctx.GetPlace(), dev_ctx, x_grad);
// TensorCopy may change the dims of dx
x_grad->Resize(x_dims);
} else {
PADDLE_ENFORCE_GE(dims,
1,
errors::InvalidArgument(
"Th rank of the input 'Out@GRAD' for tile_grad op "
" must be greater than or equal to 1, but "
"the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for tile_grad op "
"must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
dims));
switch (dims) {
case 1:
TileBackward<Context, T, 1>(
dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
break;
case 2:
TileBackward<Context, T, 2>(
dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
break;
case 3:
TileBackward<Context, T, 3>(
dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
break;
case 4:
TileBackward<Context, T, 4>(
dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
break;
case 5:
TileBackward<Context, T, 5>(
dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
break;
case 6:
TileBackward<Context, T, 6>(
dev_ctx, out_grad, reshape_dims_vec, reduce_dims_vec, x_grad);
break;
default:
PADDLE_THROW(errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <vector>
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/tile_kernel.h"
namespace phi {
template <typename Context, typename T, int Rank>
void Tile(const Context& dev_ctx,
const DenseTensor& x,
std::vector<int64_t> repeat_times,
DenseTensor* out) {
auto x_dims = x.dims();
for (size_t i = 0; i < repeat_times.size(); ++i) {
PADDLE_ENFORCE_GT(
repeat_times[i],
0,
errors::InvalidArgument(
"All elements of the input 'repeat_times' for tile op must "
"be positive integers, but the value received is %d.",
repeat_times[i]));
}
auto vec_x_dims = phi::vectorize<int>(x_dims);
if (repeat_times.size() < vec_x_dims.size()) {
int diff = vec_x_dims.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, 1);
} else {
int diff = repeat_times.size() - vec_x_dims.size();
vec_x_dims.insert(vec_x_dims.begin(), diff, 1);
}
PADDLE_ENFORCE_EQ(
repeat_times.size(),
vec_x_dims.size(),
errors::InvalidArgument(
"The rank (%d) of the input 'x' and the rank (%d) of the input "
"'repeat_times' for tile op must match after promotion.",
vec_x_dims.size(),
repeat_times.size()));
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
DDim new_x_dims = make_ddim(vec_x_dims);
DDim out_dims(new_x_dims);
for (size_t i = 0; i < repeat_times.size(); ++i) {
out_dims[i] *= repeat_times[i];
}
out->Resize(out_dims);
auto eigen_x = EigenTensor<T, Rank>::From(x, new_x_dims);
dev_ctx.template Alloc<T>(out);
auto eigen_out = EigenTensor<T, Rank>::From(*out, out_dims);
auto& place = *dev_ctx.eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(eigen_out), To32BitIndex(eigen_x), bcast_dims);
} else {
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, eigen_out, eigen_x, bcast_dims);
}
}
template <typename T, typename Context>
void TileKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& repeat_times,
DenseTensor* out) {
auto rank = x.dims().size();
auto& repeat_times_data = repeat_times.GetData();
int repeat_times_size = repeat_times_data.size();
rank = std::max(rank, repeat_times_size);
switch (rank) {
case 1:
Tile<Context, T, 1>(dev_ctx, x, repeat_times_data, out);
break;
case 2:
Tile<Context, T, 2>(dev_ctx, x, repeat_times_data, out);
break;
case 3:
Tile<Context, T, 3>(dev_ctx, x, repeat_times_data, out);
break;
case 4:
Tile<Context, T, 4>(dev_ctx, x, repeat_times_data, out);
break;
case 5:
Tile<Context, T, 5>(dev_ctx, x, repeat_times_data, out);
break;
case 6:
Tile<Context, T, 6>(dev_ctx, x, repeat_times_data, out);
break;
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#define MAX_RANK_SUPPORTED 6
namespace phi {
template <typename T, typename Context>
void TileGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& repeat_times,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#define MAX_RANK_SUPPORTED 6
namespace phi {
template <typename T, typename Context>
void TileKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& repeat_times,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TileOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("RepeatTimes")) {
return KernelSignature("tile", {"X"}, {"RepeatTimes"}, {"Out"});
} else if (ctx.InputSize("repeat_times_tensor") > 0) {
return KernelSignature("tile", {"X"}, {"repeat_times_tensor"}, {"Out"});
} else {
return KernelSignature("tile", {"X"}, {"repeat_times"}, {"Out"});
}
}
KernelSignature TileGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("RepeatTimes")) {
return KernelSignature("tile_grad",
{"X", GradVarName("Out")},
{"RepeatTimes"},
{GradVarName("X")});
} else if (ctx.InputSize("repeat_times_tensor") > 0) {
return KernelSignature("tile_grad",
{"X", GradVarName("Out")},
{"repeat_times_tensor"},
{GradVarName("X")});
} else {
return KernelSignature("tile_grad",
{"X", GradVarName("Out")},
{"repeat_times"},
{GradVarName("X")});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(tile, phi::TileOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tile_grad, phi::TileGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册