未验证 提交 31b1f707 编写于 作者: B baoachun 提交者: GitHub

refactor the forward implementation of reshape npu op (#38748)

* refactor the forward implementation of reshape npu op

* update reshape npu op

* update reshape npu op
上级 7d4ce5b3
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
...@@ -25,23 +26,93 @@ template <typename DeviceContext, typename T> ...@@ -25,23 +26,93 @@ template <typename DeviceContext, typename T>
class Reshape2NPUKernel : public framework::OpKernel<T> { class Reshape2NPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
auto place = ctx.GetPlace();
auto* x = ctx.Input<framework::Tensor>("X"); auto* x = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensor"); std::vector<int32_t> target_shape_vector;
if (list_new_shape_tensor.size() > 0) { auto shape_tensor_vector = ctx.MultiInput<framework::Tensor>("ShapeTensor");
PADDLE_THROW(platform::errors::Unimplemented( if (shape_tensor_vector.size() > 0) {
"Input(ShapeTensor) is not supported on NPU.")); for (auto* shape_tensor : shape_tensor_vector) {
PADDLE_ENFORCE_EQ(
shape_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"If the element type of 'shape' in Reshape Op is Tensor, "
"the element's shape must be [1]. But received the element's "
"shape is [%d]",
shape_tensor->dims().size()));
target_shape_vector.push_back(GetDataFromTensor<int>(shape_tensor)[0]);
}
} else {
auto* shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
if (shape_tensor) {
target_shape_vector = GetDataFromTensor<int>(shape_tensor);
} else {
target_shape_vector = ctx.Attr<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(
target_shape_vector.size(), 0,
platform::errors::InvalidArgument(
"The length of shape attribute should be larger than 0 when "
"input ShapeTensor and Shape are empty!"));
}
} }
PADDLE_ENFORCE_EQ(ctx.Input<framework::LoDTensor>("Shape"), nullptr,
platform::errors::Unimplemented( int num_negative =
"Input(Shape) is not supported on NPU.")); std::count(target_shape_vector.begin(), target_shape_vector.end(), -1);
auto shape = out->dims(); PADDLE_ENFORCE_LE(
out->mutable_data(ctx.GetPlace(), x->type()); num_negative, 1,
framework::TensorCopy( platform::errors::InvalidArgument(
*x, ctx.GetPlace(), "The max number of -1 in shape attribute or shape tensor is 1 "
ctx.template device_context<platform::DeviceContext>(), out); "but received %d.",
out->Resize(shape); num_negative));
auto it_zero =
std::find(target_shape_vector.begin(), target_shape_vector.end(), 0);
if (it_zero != target_shape_vector.end()) {
int x_rank = x->dims().size();
for (size_t i = 0; i < target_shape_vector.size(); i++) {
if (target_shape_vector[i] == 0) {
PADDLE_ENFORCE_LT(
i, x_rank,
platform::errors::InvalidArgument(
"The index of 0 in shape attribute or shape tensor",
"should be less than input dim size, ",
"but the index is %d and input dim size is %d", i, x_rank));
target_shape_vector[i] = x->dims().at(i);
}
}
}
auto it =
std::find(target_shape_vector.begin(), target_shape_vector.end(), -1);
if (it != target_shape_vector.end()) {
auto ddim_out_vec = framework::vectorize(x->dims());
int ddim_out_product = std::accumulate(
ddim_out_vec.begin(), ddim_out_vec.end(), 1, std::multiplies<int>());
int reshape_out_product = std::accumulate(target_shape_vector.begin(),
target_shape_vector.end(), -1,
std::multiplies<int>());
int index = std::distance(target_shape_vector.begin(), it);
target_shape_vector[index] = ddim_out_product / reshape_out_product;
}
auto out_dims = framework::make_ddim(target_shape_vector);
out->mutable_data<T>(out_dims, place);
NpuOpRunner runner;
// the shape input must be on the host side
runner.SetType("Reshape")
.AddInput(*x)
.AddInput(std::vector<int32_t>(target_shape_vector))
.AddOutput(*out)
.AddAttr("axis", 0)
.AddAttr("num_axes", -1);
runner.Run(stream);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册