Unverified Commit 2d16d69b authored by Linjie Chen, committed by GitHub

[Pten]Move expand_v2 to pten (#39471)

* move expand to pten

* move expand_v2 to pten

* move expand_v2 to pten

* fix grad register

* fix grad register

* fix tensorcopy

* fix tensorcopy

* fix tensorcopy

* fix tensorcopy

* fix tensorcopy

* fix ci

* fix tensorcopy
Parent ab866777
...@@ -16,6 +16,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#define MAX_RANK_SUPPORTED 6
namespace paddle {
namespace operators {
...@@ -296,33 +299,3 @@ REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp,
ops::ExpandV2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2DoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ExpandV2GradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand_v2, ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_v2_grad,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandV2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL(
expand_v2, ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandV2Kernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_v2_grad,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandV2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
#endif
...@@ -91,259 +91,5 @@ inline std::vector<int> get_expand_shape(
return ctx.Attr<std::vector<int>>("shape");
}
}
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::To32BitIndex;
template <typename DeviceContext, typename T>
class ExpandV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
PADDLE_ENFORCE_GE(
rank, 1,
platform::errors::InvalidArgument(
"The rank of the input 'X' for expand_v2 op must be positive, "
"but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'X' for expand_v2 op must be less than "
"or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
auto expand_shape = get_expand_shape(context);
auto shape_size = expand_shape.size();
PADDLE_ENFORCE_GE(
shape_size, rank,
platform::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand_v2 op must be "
"greater than or equal to the rank (%d) of the input 'X'.",
shape_size, rank));
PADDLE_ENFORCE_LE(
shape_size, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand_v2 op must be "
"less than or equal to %d.",
shape_size, MAX_RANK_SUPPORTED));
rank = std::max(rank, static_cast<int>(shape_size));
switch (rank) {
case 1:
Expand<1>(context);
break;
case 2:
Expand<2>(context);
break;
case 3:
Expand<3>(context);
break;
case 4:
Expand<4>(context);
break;
case 5:
Expand<5>(context);
break;
case 6:
Expand<6>(context);
break;
}
}
protected:
template <int Rank>
void Expand(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto expand_shape = get_expand_shape(context);
auto vec_in_dims = framework::vectorize<int>(in_dims);
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(expand_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
expand_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_v2 op.",
expand_shape[i]));
repeat_times[i] = expand_shape[i];
} else if (expand_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], expand_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand_v2 op.",
vec_in_dims[i], expand_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = expand_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
expand_shape[i], -1,
platform::errors::InvalidArgument(
"When the value in shape is negative for expand_v2 op, "
"only -1 is supported, but the value received is %d.",
expand_shape[i]));
repeat_times[i] = 1;
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
framework::DDim out_dims(new_in_dims);
for (size_t i = 0; i < repeat_times.size(); ++i) {
out_dims[i] *= repeat_times[i];
}
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
} else {
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
}
};
template <typename DeviceContext, typename T>
class ExpandV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto expand_shape = get_expand_shape(context);
auto x_dims = in0->dims();
auto vec_in_dims = framework::vectorize<int>(x_dims);
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
if (expand_shape[i] < 0) {
repeat_times[i] = 1;
} else {
repeat_times[i] = expand_shape[i] / vec_in_dims[i];
}
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) {
case 1:
ExpandBackward<1>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 2:
ExpandBackward<2>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 3:
ExpandBackward<3>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 4:
ExpandBackward<4>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 5:
ExpandBackward<5>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 6:
ExpandBackward<6>(context, reshape_dims_vec, reduce_dims_vec);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}
}
protected:
template <int Dims>
void ExpandBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
} // namespace operators
} // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_v2_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
......
...@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/expand_v2_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
...
...@@ -45,6 +45,8 @@ const std::unordered_set<std::string> deprecated_op_names({"flatten",
"mean",
"reshape",
"reshape_grad",
"expand",
"expand_grad",
"sum"}); "sum"});
class DefaultKernelSignatureMap { class DefaultKernelSignatureMap {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/expand_grad_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/expand_grad_kernel_impl.h"
PT_REGISTER_KERNEL(expand_grad,
CPU,
ALL_LAYOUT,
pten::ExpandGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/expand_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/expand_kernel_impl.h"
PT_REGISTER_KERNEL(expand,
CPU,
ALL_LAYOUT,
pten::ExpandKernel,
float,
double,
int,
int64_t,
bool) {}
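// Note: these CPU type lists mirror the REGISTER_OP_CPU_KERNEL lists removed
// from expand_v2_op.cc above (bool is forward-only, so expand_grad omits it);
// float16 variants are registered only in the GPU files later in this diff.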
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/device_context.h"
namespace pten {
template <typename T, typename Context>
void ExpandGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& shape,
DenseTensor* in_grad);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/device_context.h"
namespace pten {
template <typename T, typename Context>
void ExpandKernel(const Context& ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out);
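// A rough usage sketch (hypothetical, not from this diff; assumes a prepared
// CPUContext `dev_ctx` and an initialized DenseTensor `x` of shape [3]):
//
//   pten::DenseTensor out;
//   std::vector<int64_t> target_shape = {2, 3};
//   pten::ExpandKernel<float, pten::CPUContext>(
//       dev_ctx, x, pten::ScalarArray(target_shape), &out);
//
// The kernel validates ranks, then resizes and allocates `out` itself,
// producing shape [2, 3].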
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/expand_grad_kernel.h"
#include "paddle/pten/kernels/impl/expand_grad_kernel_impl.h"
PT_REGISTER_KERNEL(expand_grad,
GPU,
ALL_LAYOUT,
pten::ExpandGradKernel,
float,
double,
paddle::platform::float16,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/expand_kernel.h"
#include "paddle/pten/kernels/impl/expand_kernel_impl.h"
PT_REGISTER_KERNEL(expand,
GPU,
ALL_LAYOUT,
pten::ExpandKernel,
float,
double,
paddle::platform::float16,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/impl/expand_kernel_impl.h"
namespace pten {
template <typename Context, typename T, int Dims>
void ExpandBackward(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec,
DenseTensor* in_grad) {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
ctx.template Alloc<T>(in_grad);
in_grad->data<T>();
auto x_grad = EigenVector<T>::Flatten(*in_grad);
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad0 = EigenVector<T>::Flatten(out_grad);
auto& place = *ctx.eigen_device();
pten::funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad0, reduce_dims, reshape_dims);
}
template <typename T, typename Context>
void ExpandGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const ScalarArray& shape,
DenseTensor* in_grad) {
auto expand_shape = shape.GetData();
auto x_dims = x.dims();
auto vec_in_dims = framework::vectorize<int>(x_dims);
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
// 1. reshape_dims_vec is the broadcast parameter.
// 2. reduce_dims_vec is the dimension parameter to compute gradients. For
// each dimension expanded, the gradients should be summed to original
// size.
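// Hypothetical example (not from this diff): expanding x of shape [3] with
// shape = [2, 3] pads vec_in_dims to [1, 3] and gives repeat_times = [2, 1];
// the loops below then build reshape_dims_vec = [2, 1, 1, 3] and
// reduce_dims_vec = {0, 2}, so out_grad is viewed as [2, 1, 1, 3] and summed
// over dims 0 and 2 to recover in_grad of shape [1, 3].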
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
if (expand_shape[i] < 0) {
repeat_times[i] = 1;
} else {
repeat_times[i] = expand_shape[i] / vec_in_dims[i];
}
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
pten::Copy(ctx, out_grad, false, in_grad);
} else {
PADDLE_ENFORCE_GE(dims,
1,
pten::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims,
MAX_RANK_SUPPORTED,
pten::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
dims));
switch (dims) {
case 1:
ExpandBackward<Context, T, 1>(
ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 2:
ExpandBackward<Context, T, 2>(
ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 3:
ExpandBackward<Context, T, 3>(
ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 4:
ExpandBackward<Context, T, 4>(
ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 5:
ExpandBackward<Context, T, 5>(
ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 6:
ExpandBackward<Context, T, 6>(
ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
default:
PADDLE_THROW(pten::errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
namespace pten {
using Tensor = DenseTensor;
template <typename Context, typename T, int Rank>
void Expand(const Context& ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out) {
auto in_dims = x.dims();
auto expand_shape = shape.GetData();
auto vec_in_dims = framework::vectorize<int>(in_dims);
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(
expand_shape[i],
0,
pten::errors::InvalidArgument("The expanded size cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
expand_shape[i],
0,
pten::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_v2 op.",
expand_shape[i]));
repeat_times[i] = expand_shape[i];
} else if (expand_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i],
expand_shape[i],
pten::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand_v2 op.",
vec_in_dims[i],
expand_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = expand_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
expand_shape[i],
-1,
pten::errors::InvalidArgument(
"When the value in shape is negative for expand_v2 op, "
"only -1 is supported, but the value received is %d.",
expand_shape[i]));
repeat_times[i] = 1;
}
}
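// Hypothetical example (not from this diff): x of shape [3] with
// shape = [2, 3] pads vec_in_dims to [1, 3] and yields repeat_times = [2, 1];
// bcast_dims below is then {2, 1} and out_dims becomes [2, 3].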
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
framework::DDim out_dims(new_in_dims);
for (size_t i = 0; i < repeat_times.size(); ++i) {
out_dims[i] *= repeat_times[i];
}
out->Resize(out_dims);
auto x0 = EigenTensor<T, Rank>::From(x, new_in_dims);
ctx.template Alloc<T>(out);
out->data<T>();
auto y = EigenTensor<T, Rank>::From(*out, out_dims);
auto& place = *ctx.eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
pten::funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(y), To32BitIndex(x0), bcast_dims);
} else {
pten::funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, y, x0, bcast_dims);
}
}
template <typename T, typename Context>
void ExpandKernel(const Context& ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out) {
auto rank = x.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
pten::errors::InvalidArgument(
"The rank of the input 'X' for expand_v2 op must be positive, "
"but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
MAX_RANK_SUPPORTED,
pten::errors::InvalidArgument(
"The rank of the input 'X' for expand_v2 op must be less than "
"or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
rank));
auto expand_shape = shape.GetData();
auto shape_size = expand_shape.size();
PADDLE_ENFORCE_GE(
shape_size,
rank,
pten::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand_v2 op must be "
"greater than or equal to the rank (%d) of the input 'X'.",
shape_size,
rank));
PADDLE_ENFORCE_LE(
shape_size,
MAX_RANK_SUPPORTED,
pten::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand_v2 op must be "
"less than or equal to %d.",
shape_size,
MAX_RANK_SUPPORTED));
rank = std::max(rank, static_cast<int>(shape_size));
switch (rank) {
case 1:
Expand<Context, T, 1>(ctx, x, shape, out);
break;
case 2:
Expand<Context, T, 2>(ctx, x, shape, out);
break;
case 3:
Expand<Context, T, 3>(ctx, x, shape, out);
break;
case 4:
Expand<Context, T, 4>(ctx, x, shape, out);
break;
case 5:
Expand<Context, T, 5>(ctx, x, shape, out);
break;
case 6:
Expand<Context, T, 6>(ctx, x, shape, out);
break;
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Shape")) {
return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) {
return KernelSignature("expand", {"X"}, {"expand_shapes_tensor"}, {"Out"});
} else {
return KernelSignature("expand", {"X"}, {"shape"}, {"Out"});
}
}
KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Shape")) {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"Shape"},
{GradVarName("X")});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"expand_shapes_tensor"},
{GradVarName("X")});
} else {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"shape"},
{GradVarName("X")});
}
}
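// Hypothetical illustration of the mappings above: an expand_v2 op whose
// target shape comes only from its "shape" attribute resolves to
//   KernelSignature("expand", {"X"}, {"shape"}, {"Out"})
// and its grad op to
//   KernelSignature("expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"})
// since GradVarName("Out") expands to "Out@GRAD".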
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(expand_v2, expand);
PT_REGISTER_BASE_KERNEL_NAME(expand_v2_grad, expand_grad);
PT_REGISTER_ARG_MAPPING_FN(expand_v2, pten::ExpandOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(expand_v2_grad, pten::ExpandGradOpArgumentMapping);