Unverified commit f014e301, authored by Qi Li, committed by GitHub

[NPU] add int64_t kernels for YoloV3, test=develop (#35045)

* [NPU] add int64 kernels, test=develop

* update ci scripts to be able to turn WITH_ASCEND_INT64 on, test=develop
Parent f13dcfb1
......@@ -224,6 +224,7 @@ option(WITH_UNITY_BUILD "Compile with UnityBuild mode" OFF)
option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
# PY_VERSION
if(NOT PY_VERSION)
......
......@@ -90,6 +90,10 @@ if(WITH_ASCEND_CL)
add_definitions(-DPADDLE_WITH_ASCEND_CL)
endif()
if(WITH_ASCEND_INT64)
add_definitions(-DPADDLE_WITH_ASCEND_INT64)
endif()
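Each operator touched below consumes this define directly at kernel-registration time, so configurations built without -DWITH_ASCEND_INT64=ON simply omit the extra instantiation. A minimal sketch of the recurring pattern (some_op and SomeNPUKernel are placeholders, not names from this diff):

    REGISTER_OP_NPU_KERNEL(
        some_op, paddle::operators::SomeNPUKernel<float>,
    #ifdef PADDLE_WITH_ASCEND_INT64
        paddle::operators::SomeNPUKernel<int64_t>,  // compiled only when WITH_ASCEND_INT64=ON
    #endif
        paddle::operators::SomeNPUKernel<paddle::platform::float16>);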
if(WITH_XPU)
message(STATUS "Compile with XPU!")
add_definitions(-DPADDLE_WITH_XPU)
......
......@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/utils.h"
......@@ -22,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class FillConstantNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -32,7 +29,6 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
auto float_value = ctx.Attr<float>("value");
auto* out_var = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -63,25 +59,28 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
}
auto shape = GetShape(ctx);
Tensor tensor_tmp(data_type);
tensor_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, value);
Tensor tensor_value(data_type);
tensor_value.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_value, value);
out_var->mutable_data<T>(shape, ctx.GetPlace());
out_var->mutable_data<T>(shape, place);
const auto& runner = NpuOpRunner("FillD", {tensor_tmp}, {*out_var},
{{"dims", framework::vectorize(shape)}});
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("Fill")
.AddInput(framework::vectorize(shape))
.AddInput(tensor_value)
.AddOutput(*out_var)
.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
fill_constant,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, bool>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
fill_constant, paddle::operators::FillConstantNPUKernel<float>,
paddle::operators::FillConstantNPUKernel<bool>,
paddle::operators::FillConstantNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
paddle::operators::FillConstantNPUKernel<int64_t>,
#endif
paddle::operators::FillConstantNPUKernel<paddle::platform::float16>);
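A note on the rewrite above: the fixed-shape FillD op received the output shape as a compile-time "dims" attribute, whereas Fill takes the shape as an ordinary input, which is why the kernel now feeds framework::vectorize(shape) through AddInput on the builder-style runner instead of through an attribute map; moving the shape from attribute to input is also what makes the new int64_t instantiation representable. The kernel is re-templated on T alone because the removed DeviceContext parameter was never used in the class body.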
......@@ -14,22 +14,11 @@
#include "paddle/fluid/operators/increment_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class IncrementalNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -56,13 +45,11 @@ class IncrementalNPUKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
increment,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext,
plat::float16>)
increment, paddle::operators::IncrementalNPUKernel<float>,
paddle::operators::IncrementalNPUKernel<double>,
paddle::operators::IncrementalNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
paddle::operators::IncrementalNPUKernel<int64_t>,
#endif
paddle::operators::IncrementalNPUKernel<paddle::platform::float16>)
......@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class MeshgridNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -65,9 +65,12 @@ class MeshgridNPUKernel : public framework::OpKernel<T> {
auto stream =
context.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& runner = NpuOpRunner("BroadcastToD", {reshape_ins_tensor},
{*(outs[i])}, {{"shape", shape}});
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("BroadcastTo")
.AddInput(reshape_ins_tensor)
.AddInput(std::move(shape))
.AddOutput(*(outs[i]))
.Run(stream);
}
}
};
......@@ -75,10 +78,10 @@ class MeshgridNPUKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
meshgrid, ops::MeshgridNPUKernel<plat::NPUDeviceContext, float>,
ops::MeshgridNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::MeshgridNPUKernel<plat::NPUDeviceContext, int32_t>);
meshgrid, paddle::operators::MeshgridNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
paddle::operators::MeshgridNPUKernel<int64_t>,
#endif
paddle::operators::MeshgridNPUKernel<float>,
paddle::operators::MeshgridNPUKernel<paddle::platform::float16>);
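The BroadcastToD-to-BroadcastTo switch mirrors the FillD-to-Fill change above: the target shape moves from a compile-time "shape" attribute to a second runtime input (hence AddInput(std::move(shape))), so a host vector of int64 extents can be passed to the op directly.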
......@@ -12,24 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class RangeNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -41,19 +30,19 @@ class RangeNPUKernel : public framework::OpKernel<T> {
framework::Tensor n;
framework::TensorCopy(
*start_t, platform::CPUPlace(),
context.template device_context<platform::DeviceContext>(), &n);
context.template device_context<platform::NPUDeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
T start = n.data<T>()[0];
framework::TensorCopy(
*end_t, platform::CPUPlace(),
context.template device_context<platform::DeviceContext>(), &n);
context.template device_context<platform::NPUDeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
T end = n.data<T>()[0];
framework::TensorCopy(
*step_t, platform::CPUPlace(),
context.template device_context<platform::DeviceContext>(), &n);
context.template device_context<platform::NPUDeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
T step = n.data<T>()[0];
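The context-type correction here is load-bearing: with an NPUDeviceContext, TensorCopy enqueues an asynchronous device-to-host copy on the NPU stream, so each copy is now followed by an explicit Wait() before n.data<T>()[0] is dereferenced. A minimal sketch of the pattern, with dev_scalar and npu_ctx as placeholder names:

    framework::Tensor host_scalar;
    framework::TensorCopy(*dev_scalar, platform::CPUPlace(), npu_ctx, &host_scalar);
    npu_ctx.Wait();                        // block until the async copy lands
    T value = host_scalar.data<T>()[0];    // only now safe to read on the host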
......@@ -78,11 +67,9 @@ class RangeNPUKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
range, ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, double>)
REGISTER_OP_NPU_KERNEL(range, paddle::operators::RangeNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
paddle::operators::RangeNPUKernel<int64_t>,
#endif
paddle::operators::RangeNPUKernel<float>,
paddle::operators::RangeNPUKernel<double>)
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T>
class ScaleNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -40,37 +40,21 @@ class ScaleNPUKernel : public framework::OpKernel<T> {
auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor));
}
if (bias_after_scale) {
out->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Power", {*x}, {*out},
{{"power", power}, {"scale", scale}, {"shift", bias}});
runner.Run(stream);
} else {
Tensor tmp_x(x->type());
tmp_x.Resize(x->dims());
tmp_x.mutable_data<T>(ctx.GetPlace());
const auto& runner_tmp =
NpuOpRunner("Adds", {*x}, {tmp_x}, {{"value", bias}});
runner_tmp.Run(stream);
out->mutable_data<T>(ctx.GetPlace());
float _bias = 0.0;
const auto& runner =
NpuOpRunner("Power", {tmp_x}, {*out},
{{"power", power}, {"scale", scale}, {"shift", _bias}});
runner.Run(stream);
if (!bias_after_scale) {
bias *= scale;
}
out->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Power", {*x}, {*out},
{{"power", power}, {"scale", scale}, {"shift", bias}});
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
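The collapsed branch above rests on a one-line identity: with power = 1, the Power op computes out = scale * x + shift, and

    scale * (x + bias) = scale * x + (scale * bias)

so pre-folding bias *= scale when bias_after_scale is false makes the single Power call equivalent to the old Adds-then-Power sequence, saving one op launch and the temporary tensor.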
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
scale, ops::ScaleNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ScaleNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
scale, paddle::operators::ScaleNPUKernel<float>,
paddle::operators::ScaleNPUKernel<paddle::platform::float16>);
......@@ -20,7 +20,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
template <typename T>
class StackNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -49,7 +49,7 @@ class StackNPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T>
class StackGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -81,15 +81,18 @@ class StackGradNPUKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
stack, ops::StackNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::StackNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
stack, paddle::operators::StackNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
paddle::operators::StackNPUKernel<int64_t>,
#endif
paddle::operators::StackNPUKernel<float>,
paddle::operators::StackNPUKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
stack_grad,
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
    stack_grad, paddle::operators::StackGradNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
    paddle::operators::StackGradNPUKernel<int64_t>,
#endif
    paddle::operators::StackGradNPUKernel<float>,
    paddle::operators::StackGradNPUKernel<paddle::platform::float16>);
......@@ -228,6 +228,7 @@ function cmake_base() {
-DWITH_ARM=${WITH_ARM:-OFF}
-DWITH_ASCEND=${WITH_ASCEND:-OFF}
-DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF}
-DWITH_ASCEND_INT64=${WITH_ASCEND_INT64:-OFF}
-DWITH_STRIP=${WITH_STRIP:-ON}
-DON_INFER=${ON_INFER:-OFF}
========================================
......@@ -269,6 +270,7 @@ EOF
-DWITH_ARM=${WITH_ARM:-OFF} \
-DWITH_ASCEND=${WITH_ASCEND:-OFF} \
-DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF} \
-DWITH_ASCEND_INT64=${WITH_ASCEND_INT64:-OFF} \
-DWITH_STRIP=${WITH_STRIP:-ON} \
-DON_INFER=${ON_INFER:-OFF} \
-DWITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF};build_error=$?
......
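With the plumbing above in place, the feature can be enabled from the environment when driving the build script, e.g. (assuming the script's usual build entry point):

    WITH_ASCEND_CL=ON WITH_ASCEND_INT64=ON bash paddle/scripts/paddle_build.sh build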
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -72,6 +72,30 @@ class TestFillConstantInt(OpTest):
self.check_output_with_place(self.place)
class TestFillConstantInt64(OpTest):
    def setUp(self):
        self.set_npu()
        self.init_dtype()
        self.place = paddle.NPUPlace(0)
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {
'shape': [123, 92],
'value': 1,
'dtype': core.VarDesc.VarType.INT64
}
self.outputs = {'Out': np.full((123, 92), 1).astype(self.dtype)}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.int64
def test_check_output(self):
self.check_output_with_place(self.place)
class TestFillConstantFP16(OpTest):
def setUp(self):
self.set_npu()
......
......@@ -81,6 +81,32 @@ class TestIncrementFP16(OpTest):
self.check_output_with_place(self.place)
class TestIncrementINT64(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(NPUPlace)
self.op_type = "increment"
self.init_dtype()
self.inputs = {
'X':
OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)),
}
self.pre_input_id = id(self.inputs['X'])
        self.attrs = {"step": 1}
self.outputs = {'Out': np.array([2])}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.int64
def test_check_output(self):
self.check_output_with_place(self.place)
class TestIncrementInplace(unittest.TestCase):
def test_npu(self):
main_prog = paddle.static.Program()
......
......@@ -75,6 +75,16 @@ class TestMeshgridOpFP16(TestMeshgridOp):
return "float16"
class TestMeshgridOpINT32(TestMeshgridOp):
def get_dtype(self):
return "int32"
class TestMeshgridOpINT64(TestMeshgridOp):
def get_dtype(self):
return "int64"
class TestMeshgridOp2(TestMeshgridOp):
def get_x_shape(self):
return [100, 300]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
paddle.enable_static()
class TestRangeOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def setUp(self):
self.set_npu()
self.op_type = "range"
self.init_config()
self.inputs = {
'Start': np.array([self.case[0]]).astype(self.dtype),
'End': np.array([self.case[1]]).astype(self.dtype),
'Step': np.array([self.case[2]]).astype(self.dtype)
}
self.outputs = {
'Out': np.arange(self.case[0], self.case[1],
self.case[2]).astype(self.dtype)
}
def init_config(self):
self.dtype = np.float32
self.case = (0, 1, 0.2)
def test_check_output(self):
self.check_output_with_place(self.place)
class TestFloatRangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.float32
self.case = (0, 5, 1)
class TestInt32RangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (0, 5, 2)
class TestInt32RangeOpCase1(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (10, 1, -2)
class TestInt32RangeOpCase2(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (-1, -10, -2)
class TestInt64RangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.int64
self.case = (0, 5, 2)
class TestInt64RangeOpCase1(TestRangeOp):
def init_config(self):
self.dtype = np.int64
self.case = (10, 1, -2)
class TestInt64RangeOpCase2(TestRangeOp):
def init_config(self):
self.dtype = np.int64
self.case = (-1, -10, -2)
if __name__ == "__main__":
unittest.main()
......@@ -72,6 +72,8 @@ class TestStackOpBase(OpTest):
self.check_output_with_place(self.place)
def test_check_grad(self):
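        # integer dtypes have no gradient defined, so skip the grad check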
if self.dtype == np.int32 or self.dtype == np.int64:
return
self.check_grad_with_place(self.place, self.get_x_names(), 'Y')
......@@ -105,6 +107,16 @@ class TestStackOp6(TestStackOpBase):
self.axis = 3
class TestStackOpINT32(TestStackOpBase):
def init_dtype(self):
self.dtype = np.int32
class TestStackOpINT64(TestStackOpBase):
def init_dtype(self):
self.dtype = np.int64
class TestStackAPIWithLoDTensorArray(unittest.TestCase):
"""
Test stack api when the input(x) is a LoDTensorArray.
......@@ -180,15 +192,15 @@ class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase):
class API_test(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float64')
data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float64')
data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float64')
data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float32')
data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float32')
data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float32')
result_stack = paddle.stack([data1, data2, data3], axis=0)
place = paddle.NPUPlace(0)
exe = fluid.Executor(place)
input1 = np.random.random([1, 2]).astype('float64')
input2 = np.random.random([1, 2]).astype('float64')
input3 = np.random.random([1, 2]).astype('float64')
input1 = np.random.random([1, 2]).astype('float32')
input2 = np.random.random([1, 2]).astype('float32')
input3 = np.random.random([1, 2]).astype('float32')
result, = exe.run(
feed={"data1": input1,
"data2": input2,
......