未验证 提交 e674af23 编写于 作者: B baoachun 提交者: GitHub

refactor the forward implementation of shape npu op (#39613)

上级 c5179772
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......@@ -27,20 +24,20 @@ template <typename DeviceContext, typename T>
class ShapeNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_var = ctx.InputVar("Input");
framework::DDim in_dims;
if (in_var->IsType<pten::SelectedRows>()) {
in_dims = in_var->Get<pten::SelectedRows>().value().dims();
} else {
in_dims = in_var->Get<LoDTensor>().dims();
}
auto* x = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
out_t->Resize({in_dims.size()});
// to do: cpuplace?
auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
out_t->Resize({x->dims().size()});
out_t->mutable_data<int32_t>(ctx.GetPlace());
// The output data type defaults to int32.
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner;
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
runner.SetType("Shape").AddInput(*x).AddOutput(*out_t).AddAttr(
"dtype", static_cast<int>(dst_dtype));
runner.Run(stream);
}
};
......@@ -55,5 +52,7 @@ REGISTER_OP_NPU_KERNEL(
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, int8_t>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, uint8_t>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, double>);
......@@ -51,5 +51,25 @@ class TestShape(OpTest):
self.check_output_with_place(self.place)
class TestShape_fp16(TestShape):
def init_dtype(self):
self.dtype = np.float16
class TestShape_double(TestShape):
def init_dtype(self):
self.dtype = np.float64
class TestShape_int32(TestShape):
def init_dtype(self):
self.dtype = np.int32
class TestShape_int64(TestShape):
def init_dtype(self):
self.dtype = np.int64
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册