未验证 提交 8cabb9f3 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Phi]Move expand_as kernel to phi (#40373)

* first commit

* fix

* fix

* fix

* fix

* fix

* fix xpu and npu

* fix
上级 42ddee4e
......@@ -121,34 +121,6 @@ REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker,
ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp,
ops::ExpandAsV2GradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand_as_v2,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_as_v2_grad,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, double>);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL(
expand_as_v2,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_as_v2_grad,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);
#endif
REGISTER_OP_VERSION(expand_as_v2)
.AddCheckpoint(
......
......@@ -32,219 +32,5 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class ExpandAsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
auto target_shape = context.Attr<std::vector<int>>("target_shape");
auto target_rank = target_shape.size();
PADDLE_ENFORCE_GE(target_rank, rank,
platform::errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be greater than or equal to "
"the rank (%d) of the input 'x'.",
target_rank, rank));
PADDLE_ENFORCE_GE(rank, 1, platform::errors::InvalidArgument(
"The rank (%d) of the input 'x' for "
"expand_as_v2 op must be positive.",
rank));
PADDLE_ENFORCE_LE(target_rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be less than or equal to %d.",
target_rank, MAX_RANK_SUPPORTED));
switch (target_rank) {
case 1:
ExpandAs<1>(context);
break;
case 2:
ExpandAs<2>(context);
break;
case 3:
ExpandAs<3>(context);
break;
case 4:
ExpandAs<4>(context);
break;
case 5:
ExpandAs<5>(context);
break;
case 6:
ExpandAs<6>(context);
break;
}
}
protected:
template <int Rank>
void ExpandAs(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto target_shape = context.Attr<std::vector<int>>("target_shape");
auto vec_in_dims = phi::vectorize<int>(in_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(target_shape[i], 0,
platform::errors::InvalidArgument(
"The value of target shape cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
target_shape[i], 0,
platform::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_as_v2 op.",
target_shape[i]));
repeat_times[i] = target_shape[i];
} else if (target_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], target_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand_as_v2 op.",
vec_in_dims[i], target_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
target_shape[i], -1,
platform::errors::InvalidArgument(
"When the value in shape is negative for expand_as_v2 op, "
"only -1 is supported, but the value received is %d.",
target_shape[i]));
repeat_times[i] = 1;
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = phi::make_ddim(vec_in_dims);
framework::DDim out_dims = phi::make_ddim(target_shape);
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
};
template <typename DeviceContext, typename T>
class ExpandAsV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto target_shape = context.Attr<std::vector<int>>("target_shape");
auto x_dims = in0->dims();
auto vec_in_dims = phi::vectorize<int>(x_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
repeat_times[i] = target_shape[i] / vec_in_dims[i];
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) {
case 1:
ExpandAsBackward<1>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 2:
ExpandAsBackward<2>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 3:
ExpandAsBackward<3>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 4:
ExpandAsBackward<4>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 5:
ExpandAsBackward<5>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 6:
ExpandAsBackward<6>(context, reshape_dims_vec, reduce_dims_vec);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}
}
protected:
template <int Dims>
void ExpandAsBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
} // namespace operators
} // namespace paddle
......@@ -51,7 +51,9 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"reshape",
"reshape_grad",
"expand",
"expand_as",
"expand_grad",
"expand_as_grad",
"sum",
"top_k",
"top_k_grad"});
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_as_grad_kernel.h"
#include "paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(expand_as_grad,
CPU,
ALL_LAYOUT,
phi::ExpandAsGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/impl/expand_as_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(expand_as,
CPU,
ALL_LAYOUT,
phi::ExpandAsKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ExpandAsGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& target_shape,
DenseTensor* in_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ExpandAsKernel(const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> y,
const std::vector<int>& target_shape,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_as_grad_kernel.h"
#include "paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(expand_as_grad,
GPU,
ALL_LAYOUT,
phi::ExpandAsGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_as_kernel.h"
#include "paddle/phi/kernels/impl/expand_as_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(expand_as,
GPU,
ALL_LAYOUT,
phi::ExpandAsKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/impl/expand_as_kernel_impl.h"
namespace phi {
template <typename Context, typename T, int Dims>
void ExpandAsBackward(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec,
DenseTensor* in_grad) {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
ctx.template Alloc<T>(in_grad);
auto x_grad = EigenVector<T>::Flatten(*in_grad);
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad0 = EigenVector<T>::Flatten(out_grad);
auto& place = *ctx.eigen_device();
funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad0, reduce_dims, reshape_dims);
}
template <typename T, typename Context>
void ExpandAsGradKernel(const Context& context,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& target_shape,
DenseTensor* in_grad) {
auto x_dims = x.dims();
auto vec_in_dims = phi::vectorize<int>(x_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
repeat_times[i] = target_shape[i] / vec_in_dims[i];
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
context.template Alloc<T>(in_grad);
phi::Copy(context, out_grad, context.GetPlace(), false, in_grad);
} else {
PADDLE_ENFORCE_GE(
dims,
1,
errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
dims));
switch (dims) {
case 1:
ExpandAsBackward<Context, T, 1>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 2:
ExpandAsBackward<Context, T, 2>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 3:
ExpandAsBackward<Context, T, 3>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 4:
ExpandAsBackward<Context, T, 4>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 5:
ExpandAsBackward<Context, T, 5>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 6:
ExpandAsBackward<Context, T, 6>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
default:
PADDLE_THROW(errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
namespace phi {
template <typename Context, typename T, int Rank>
void ExpandAs(const Context& context,
const DenseTensor& x,
const std::vector<int>& target_shape,
DenseTensor* out) {
auto in_dims = x.dims();
auto vec_in_dims = phi::vectorize<int>(in_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(
target_shape[i],
0,
errors::InvalidArgument("The value of target shape cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
target_shape[i],
0,
errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand_as_v2 op.",
target_shape[i]));
repeat_times[i] = target_shape[i];
} else if (target_shape[i] > 0) {
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i],
target_shape[i],
errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand_as_v2 op.",
vec_in_dims[i],
target_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
target_shape[i],
-1,
errors::InvalidArgument(
"When the value in shape is negative for expand_as_v2 op, "
"only -1 is supported, but the value received is %d.",
target_shape[i]));
repeat_times[i] = 1;
}
}
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
phi::DDim new_in_dims = phi::make_ddim(vec_in_dims);
phi::DDim out_dims = phi::make_ddim(target_shape);
out->Resize(out_dims);
context.template Alloc<T>(out);
auto x0 = EigenTensor<T, Rank>::From(x, new_in_dims);
auto y = EigenTensor<T, Rank>::From(*out, out_dims);
auto& place = *context.eigen_device();
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, y, x0, bcast_dims);
}
template <typename T, typename Context>
void ExpandAsKernel(const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> y,
const std::vector<int>& target_shape,
DenseTensor* out) {
auto rank = x.dims().size();
auto target_rank = target_shape.size();
PADDLE_ENFORCE_GE(target_rank,
rank,
errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be greater than or equal to "
"the rank (%d) of the input 'x'.",
target_rank,
rank));
PADDLE_ENFORCE_GE(
rank,
1,
errors::InvalidArgument("The rank (%d) of the input 'x' for "
"expand_as_v2 op must be positive.",
rank));
PADDLE_ENFORCE_LE(target_rank,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be less than or equal to %d.",
target_rank,
MAX_RANK_SUPPORTED));
switch (target_rank) {
case 1:
ExpandAs<Context, T, 1>(ctx, x, target_shape, out);
break;
case 2:
ExpandAs<Context, T, 2>(ctx, x, target_shape, out);
break;
case 3:
ExpandAs<Context, T, 3>(ctx, x, target_shape, out);
break;
case 4:
ExpandAs<Context, T, 4>(ctx, x, target_shape, out);
break;
case 5:
ExpandAs<Context, T, 5>(ctx, x, target_shape, out);
break;
case 6:
ExpandAs<Context, T, 6>(ctx, x, target_shape, out);
break;
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("expand_as", {"X", "Y"}, {"target_shape"}, {"Out"});
}
KernelSignature ExpandAsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("expand_as_grad",
{"X", GradVarName("Out")},
{"target_shape"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(expand_as_v2, expand_as);
PD_REGISTER_BASE_KERNEL_NAME(expand_as_v2_grad, expand_as_grad);
PD_REGISTER_ARG_MAPPING_FN(expand_as_v2, phi::ExpandAsOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(expand_as_v2_grad,
phi::ExpandAsGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册