Unverified commit e674af23, authored by baoachun, committed via GitHub

refactor the forward implementation of shape npu op (#39613)

Parent commit: c5179772
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/shape_op.h" #include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
...@@ -27,20 +24,20 @@ template <typename DeviceContext, typename T> ...@@ -27,20 +24,20 @@ template <typename DeviceContext, typename T>
class ShapeNPUKernel : public framework::OpKernel<T> { class ShapeNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_var = ctx.InputVar("Input"); auto* x = ctx.Input<Tensor>("Input");
framework::DDim in_dims;
if (in_var->IsType<pten::SelectedRows>()) {
in_dims = in_var->Get<pten::SelectedRows>().value().dims();
} else {
in_dims = in_var->Get<LoDTensor>().dims();
}
auto* out_t = ctx.Output<Tensor>("Out"); auto* out_t = ctx.Output<Tensor>("Out");
out_t->Resize({in_dims.size()}); out_t->Resize({x->dims().size()});
// to do: cpuplace? out_t->mutable_data<int32_t>(ctx.GetPlace());
auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
for (int i = 0; i < in_dims.size(); ++i) { // The output data type defaults to int32.
out_data[i] = in_dims[i]; auto stream =
} ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner;
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
runner.SetType("Shape").AddInput(*x).AddOutput(*out_t).AddAttr(
"dtype", static_cast<int>(dst_dtype));
runner.Run(stream);
} }
}; };
...@@ -55,5 +52,7 @@ REGISTER_OP_NPU_KERNEL( ...@@ -55,5 +52,7 @@ REGISTER_OP_NPU_KERNEL(
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, int8_t>, ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, int8_t>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, uint8_t>, ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, uint8_t>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, int64_t>, ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, float>, ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, double>); ops::ShapeNPUKernel<paddle::platform::NPUDeviceContext, double>);
...@@ -51,5 +51,25 @@ class TestShape(OpTest): ...@@ -51,5 +51,25 @@ class TestShape(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
class TestShape_fp16(TestShape):
    """Shape-op NPU test variant with float16 input.

    Inherits all test machinery from TestShape (defined earlier in this
    file, outside this view); only the input dtype hook is overridden.
    """

    def init_dtype(self):
        # Presumably consumed by TestShape's setUp to build the input
        # tensor -- confirm against the base class.
        self.dtype = np.float16
class TestShape_double(TestShape):
    """Shape-op NPU test variant with float64 input.

    Inherits all test machinery from TestShape (defined earlier in this
    file, outside this view); only the input dtype hook is overridden.
    """

    def init_dtype(self):
        # Presumably consumed by TestShape's setUp to build the input
        # tensor -- confirm against the base class.
        self.dtype = np.float64
class TestShape_int32(TestShape):
    """Shape-op NPU test variant with int32 input.

    Inherits all test machinery from TestShape (defined earlier in this
    file, outside this view); only the input dtype hook is overridden.
    """

    def init_dtype(self):
        # Presumably consumed by TestShape's setUp to build the input
        # tensor -- confirm against the base class.
        self.dtype = np.int32
class TestShape_int64(TestShape):
    """Shape-op NPU test variant with int64 input.

    Inherits all test machinery from TestShape (defined earlier in this
    file, outside this view); only the input dtype hook is overridden.
    """

    def init_dtype(self):
        # Presumably consumed by TestShape's setUp to build the input
        # tensor -- confirm against the base class.
        self.dtype = np.int64
# Standard unittest entry point (de-garbled from the duplicated
# two-column diff residue): run all test cases in this module when the
# file is executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册