Unverified · Commit c038cc7a, authored by ronnywang, committed by GitHub

[NPU] Add int64 support for expand_v2, reduce_max, scale and tests (#36582)

* add TypeAdapter method for npu_op_runner

* add int64 support for elementwise_mul and reduce_sum

* add int64 support and unit tests for expand_v2, scale and reduce_max

* fix bug
Parent: adb28d67
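The commit touches the NPU kernels of elementwise_mul, expand_v2, fill_constant, reduce_max, reduce_sum and scale, their NPU unit tests, and the NpuOpRunner helper, which gains a static TypeAdapter method. TypeAdapter captures one pattern: when the underlying Ascend op has no kernel for a dtype (int64, bool), cast the inputs to a supported dtype, run the op through a callback, then cast the outputs back. A framework-free sketch of that pattern, using illustrative names and a toy op rather than Paddle's API:

// Sketch only: cast int64 data down to int32, run an op that only understands
// int32 through a callback, then cast the result back to int64.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using Int32Op = std::function<std::vector<int32_t>(const std::vector<int32_t>&)>;

std::vector<int64_t> RunWithInt32Adapter(const std::vector<int64_t>& input,
                                         const Int32Op& op) {
  std::vector<int32_t> tmp_in;
  tmp_in.reserve(input.size());
  for (int64_t v : input) tmp_in.push_back(static_cast<int32_t>(v));  // "Cast" in
  std::vector<int32_t> tmp_out = op(tmp_in);                          // run the int32-only op
  return std::vector<int64_t>(tmp_out.begin(), tmp_out.end());        // "Cast" back
}

int main() {
  // A toy "ExpandD"-like op that only supports int32: tile the input twice.
  Int32Op expand_twice = [](const std::vector<int32_t>& x) {
    std::vector<int32_t> out(x);
    out.insert(out.end(), x.begin(), x.end());
    return out;
  };
  for (int64_t v : RunWithInt32Adapter({1, 2, 3}, expand_twice)) {
    std::cout << v << ' ';  // prints: 1 2 3 1 2 3
  }
  std::cout << std::endl;
  return 0;
}

In the hunks below, the callback is the op_func lambda that each kernel passes to NpuOpRunner::TypeAdapter together with the intermediate dtypes to cast through.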
@@ -503,7 +503,6 @@ class SwishGradNPUKernel : public framework::OpKernel<T> {
     beta_x.mutable_data<T>(x->dims(), ctx.GetPlace());
     sigmoid_out.mutable_data<T>(x->dims(), ctx.GetPlace());
     swish_out.mutable_data<T>(x->dims(), ctx.GetPlace());
-
     const auto& muls_runner =
         NpuOpRunner("Muls", {*x}, {beta_x}, {{"value", beta}});
     muls_runner.Run(stream);
@@ -515,6 +514,9 @@ class SwishGradNPUKernel : public framework::OpKernel<T> {
     const auto& mul_runner =
         NpuOpRunner("Mul", {sigmoid_out, *x}, {swish_out}, {});
     mul_runner.Run(stream);
+    const auto& muls_runner2 =
+        NpuOpRunner("Muls", {swish_out}, {swish_out}, {{"value", beta}});
+    muls_runner2.Run(stream);
 
     const auto& mul_runner1 =
         NpuOpRunner("Mul", {sigmoid_out, swish_out}, {*dx}, {});
...
@@ -143,8 +143,16 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_NPU_KERNEL(elementwise_mul, ops::ElementwiseMulNPUKernel<float>,
-                       ops::ElementwiseMulNPUKernel<paddle::platform::float16>);
+                       ops::ElementwiseMulNPUKernel<paddle::platform::float16>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+                       ops::ElementwiseMulNPUKernel<int64_t>,
+#endif
+                       ops::ElementwiseMulNPUKernel<int>);
 
 REGISTER_OP_NPU_KERNEL(
     elementwise_mul_grad, ops::ElementwiseMulGradNPUKernel<float>,
-    ops::ElementwiseMulGradNPUKernel<paddle::platform::float16>);
+    ops::ElementwiseMulGradNPUKernel<paddle::platform::float16>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::ElementwiseMulGradNPUKernel<int64_t>,
+#endif
+    ops::ElementwiseMulGradNPUKernel<int>);
@@ -106,11 +106,28 @@ class ExpandV2NPUKernel : public framework::OpKernel<T> {
     Out->Resize(out_dims);
     Out->mutable_data<T>(ctx.GetPlace());
 
-    const auto& runner = NpuOpRunner("ExpandD", {*X}, {*Out}, attr_input);
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-    runner.Run(stream);
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    auto op_func = [](const std::vector<Tensor>& inputs,
+                      const std::vector<Tensor>& outputs,
+                      const NPUAttributeMap& attrs,
+                      const platform::NPUDeviceContext& dev_ctx) {
+      const auto& runner = NpuOpRunner("ExpandD", inputs, outputs, attrs);
+      runner.Run(dev_ctx.stream());
+    };
+
+    if (X->type() == framework::proto::VarType::BOOL) {
+      NpuOpRunner::TypeAdapter({*X}, {*Out}, attr_input, dev_ctx, op_func,
+                               {framework::proto::VarType::UINT8},
+                               {framework::proto::VarType::UINT8});
+    } else if (X->type() == framework::proto::VarType::INT64) {
+      NpuOpRunner::TypeAdapter({*X}, {*Out}, attr_input, dev_ctx, op_func,
+                               {framework::proto::VarType::INT32},
+                               {framework::proto::VarType::INT32});
+    } else {
+      const auto& runner = NpuOpRunner("ExpandD", {*X}, {*Out}, attr_input);
+      runner.Run(dev_ctx.stream());
+    }
   }
 };
@@ -181,7 +198,9 @@ REGISTER_OP_NPU_KERNEL(
     ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext,
                            paddle::platform::float16>,
-    ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, int>);
+    ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
+    ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, int>,
+    ops::ExpandV2NPUKernel<paddle::platform::NPUDeviceContext, bool>);
 
 REGISTER_OP_NPU_KERNEL(
     expand_v2_grad,
...
@@ -22,13 +22,13 @@ namespace operators {
 template <typename T>
 class FillConstantNPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     auto data_type =
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
     auto str_value = ctx.Attr<std::string>("str_value");
     auto float_value = ctx.Attr<float>("value");
-    auto* out_var = ctx.Output<framework::Tensor>("Out");
+    auto *out_var = ctx.Output<framework::Tensor>("Out");
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
 
@@ -59,28 +59,49 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
     }
     auto shape = GetShape(ctx);
-    Tensor tensor_value(data_type);
-    tensor_value.mutable_data<T>({1}, ctx.GetPlace());
-    FillNpuTensorWithConstant<T>(&tensor_value, value);
-
     out_var->mutable_data<T>(shape, ctx.GetPlace());
-    NpuOpRunner runner;
+    if (data_type != framework::proto::VarType::BOOL) {
+      Tensor tensor_value(data_type);
+      tensor_value.mutable_data<T>({1}, ctx.GetPlace());
+      FillNpuTensorWithConstant<T>(&tensor_value, value);
+      NpuOpRunner runner;
 #if (CANN_VERSION_CODE >= 503003)
-    runner.SetType("FillD")
-        .AddInput(tensor_value)
-        .AddOutput(*out_var)
-        .AddAttrs(
-            {{ "dims",
-               framework::vectorize(shape) }})
-        .Run(stream);
+      runner.SetType("FillD")
+          .AddInput(tensor_value)
+          .AddOutput(*out_var)
+          .AddAttrs(
+              {{ "dims",
+                 framework::vectorize(shape) }})
+          .Run(stream);
 #else
-    runner.SetType("Fill")
-        .AddInput(framework::vectorize(shape))
-        .AddInput(tensor_value)
-        .AddOutput(*out_var)
-        .Run(stream);
+      runner.SetType("Fill")
+          .AddInput(framework::vectorize(shape))
+          .AddInput(tensor_value)
+          .AddOutput(*out_var)
+          .Run(stream);
 #endif
+    } else {
+      const auto &dev_ctx =
+          ctx.template device_context<paddle::platform::NPUDeviceContext>();
+      auto op_func = [&shape, &value](
+          const std::vector<Tensor> &inputs, const std::vector<Tensor> &outputs,
+          const NPUAttributeMap &attrs,
+          const platform::NPUDeviceContext &dev_ctx) {
+        Tensor tensor_value;
+        tensor_value.mutable_data<uint8_t>({1}, dev_ctx.GetPlace());
+        FillNpuTensorWithConstant<uint8_t>(&tensor_value,
+                                           static_cast<uint8_t>(value));
+        NpuOpRunner runner;
+        runner.SetType("Fill")
+            .AddInput(framework::vectorize(shape))
+            .AddInput(tensor_value)
+            .AddOutput(outputs[0])
+            .Run(dev_ctx.stream());
+      };
+      NpuOpRunner::TypeAdapter({}, {*out_var}, {}, dev_ctx, op_func, {},
+                               {framework::proto::VarType::UINT8});
+    }
   }
 };
 
 }  // namespace operators
...
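The new else branch above covers bool fills, which the old code could not express: the constant is written as a uint8 0/1 fill and TypeAdapter casts the result back to bool afterwards, presumably because the NPU Fill op cannot produce bool output directly. A plain C++ sketch of that detour (not NPU code):

// Sketch only: materialize a bool fill through a uint8 intermediate.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  bool value = true;
  std::vector<uint8_t> filled(6, static_cast<uint8_t>(value));  // "Fill" as uint8 0/1
  std::vector<bool> out(filled.begin(), filled.end());          // cast back to bool
  std::cout << std::boolalpha << out.front() << "\n";           // prints: true
  return 0;
}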
@@ -436,5 +436,67 @@ void NpuOpRunner::Run(aclrtStream stream) const {
   PADDLE_ENFORCE_NPU_SUCCESS(ret);
 }
 
+void NpuOpRunner::TypeAdapter(
+    const std::vector<Tensor> &inputs, const std::vector<Tensor> &outputs,
+    const NPUAttributeMap &attrs, const platform::NPUDeviceContext &dev_ctx,
+    std::function<void(const std::vector<Tensor> &, const std::vector<Tensor> &,
+                       const NPUAttributeMap &,
+                       const platform::NPUDeviceContext &)>
+        op_runner,
+    const std::vector<framework::proto::VarType::Type> &input_type,
+    const std::vector<framework::proto::VarType::Type> &output_type) {
+  PADDLE_ENFORCE_EQ(
+      inputs.size(), input_type.size(),
+      platform::errors::InvalidArgument(
+          "The number of inputs must be equal to input_type.size()."));
+  PADDLE_ENFORCE_EQ(
+      outputs.size(), output_type.size(),
+      platform::errors::InvalidArgument(
+          "The number of outputs must be equal to output_type.size()."));
+
+  std::vector<Tensor> tmp_inputs(inputs.size());
+  std::vector<Tensor> tmp_outputs(outputs.size());
+
+  for (size_t i = 0; i < input_type.size(); ++i) {
+    bool cast_input =
+        (input_type[i] == -1 || input_type[i] != inputs[i].type());
+    if (!cast_input) {
+      tmp_inputs[i].ShareDataWith(inputs[i]);
+    } else {
+      tmp_inputs[i].Resize(inputs[i].dims());
+      tmp_inputs[i].mutable_data(dev_ctx.GetPlace(), input_type[i]);
+
+      const auto &cast_runner = NpuOpRunner(
+          "Cast", {inputs[i]}, {tmp_inputs[i]},
+          {{"dst_type", static_cast<int>(ConvertToNpuDtype(input_type[i]))}});
+      cast_runner.Run(dev_ctx.stream());
+    }
+  }
+  for (size_t i = 0; i < output_type.size(); ++i) {
+    bool cast_output =
+        (output_type[i] == -1 || output_type[i] != outputs[i].type());
+    if (!cast_output) {
+      tmp_outputs[i].ShareDataWith(outputs[i]);
+    } else {
+      tmp_outputs[i].Resize(outputs[i].dims());
+      tmp_outputs[i].mutable_data(dev_ctx.GetPlace(), output_type[i]);
+    }
+  }
+
+  op_runner(tmp_inputs, tmp_outputs, attrs, dev_ctx);
+
+  for (size_t i = 0; i < output_type.size(); ++i) {
+    bool cast_output =
+        (output_type[i] == -1 || output_type[i] != outputs[i].type());
+    if (cast_output) {
+      const auto &cast_runner = NpuOpRunner(
+          "Cast", {tmp_outputs[i]}, {outputs[i]},
+          {{"dst_type",
+            static_cast<int>(ConvertToNpuDtype(outputs[i].type()))}});
+      cast_runner.Run(dev_ctx.stream());
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
@@ -103,6 +103,16 @@ class NpuOpRunner {
 
   void Run(aclrtStream stream = nullptr) const;
 
+  static void TypeAdapter(
+      const std::vector<Tensor> &inputs, const std::vector<Tensor> &outputs,
+      const NPUAttributeMap &attrs, const platform::NPUDeviceContext &dev_ctx,
+      std::function<void(const std::vector<Tensor> &,
+                         const std::vector<Tensor> &, const NPUAttributeMap &,
+                         const platform::NPUDeviceContext &)>
+          op_runner,
+      const std::vector<framework::proto::VarType::Type> &input_type,
+      const std::vector<framework::proto::VarType::Type> &output_type);
+
  private:
   aclTensorDesc *CreateTensorDesc(Tensor tensor,
                                   aclMemType mem_type = ACL_MEMTYPE_DEVICE);
...
@@ -73,20 +73,33 @@ class ReduceMaxNPUKernel : public framework::OpKernel<T> {
       attr_input = {{"axes", dim_vec}, {"keep_dims", keep_dim}};
     }
 
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-
-    const auto& runner =
-        NpuOpRunner("ReduceMaxD", {*x}, {cast_out}, attr_input);
-    runner.Run(stream);
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    if (x->type() == framework::proto::VarType::INT64) {
+      auto op_func = [](const std::vector<Tensor>& inputs,
+                        const std::vector<Tensor>& outputs,
+                        const NPUAttributeMap& attrs,
+                        const platform::NPUDeviceContext& dev_ctx) {
+        const auto& runner =
+            NpuOpRunner("ReduceMaxD", {inputs[0]}, {outputs[0]}, attrs);
+        runner.Run(dev_ctx.stream());
+      };
+
+      NpuOpRunner::TypeAdapter({*x}, {cast_out}, attr_input, dev_ctx, op_func,
+                               {framework::proto::VarType::INT32},
+                               {framework::proto::VarType::INT32});
+    } else {
+      const auto& runner =
+          NpuOpRunner("ReduceMaxD", {*x}, {cast_out}, attr_input);
+      runner.Run(dev_ctx.stream());
+    }
 
     if (x->type() != cast_out_dtype) {
       auto dst_dtype = ConvertToNpuDtype(cast_out_dtype);
       const auto& runner_cast =
           NpuOpRunner("Cast", {cast_out}, {*out},
                       {{"dst_type", static_cast<int>(dst_dtype)}});
-      runner_cast.Run(stream);
+      runner_cast.Run(dev_ctx.stream());
     }
   }
 };
@@ -98,4 +111,6 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(
     reduce_max, ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, float>,
-    ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, plat::float16>);
+    ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int64_t>,
+    ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int>);
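In the int64 branch above, ReduceMaxD actually runs on int32 data; TypeAdapter casts the input down and the result back up, presumably because the Ascend ReduceMaxD op has no int64 kernel. The round trip is exact only while every value fits in int32 range, as this plain C++ illustration (not NPU code) shows:

// Sketch only: the int64 -> int32 -> int64 detour leaves the reduction result
// unchanged as long as all values fit in int32.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x = {5, 42, 7};
  std::vector<int32_t> x32(x.begin(), x.end());                // cast in (assumes int32 range)
  int64_t max_via_int32 = *std::max_element(x32.begin(), x32.end());  // reduce in int32
  int64_t max_direct = *std::max_element(x.begin(), x.end());         // reference int64 result
  std::cout << max_via_int32 << " " << max_direct << "\n";  // prints: 42 42
  return 0;
}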
@@ -142,12 +142,18 @@ namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
     reduce_sum,
     ops::ReduceSumNPUKernel<paddle::platform::NPUDeviceContext, float>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::ReduceSumNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
+#endif
     ops::ReduceSumNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::ReduceSumNPUKernel<paddle::platform::NPUDeviceContext,
                             paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
     reduce_sum_grad,
     ops::ReduceSumGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::ReduceSumGradNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
+#endif
     ops::ReduceSumGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
     ops::ReduceSumGradNPUKernel<paddle::platform::NPUDeviceContext,
                                 paddle::platform::float16>);
@@ -37,15 +37,47 @@ class ScaleNPUKernel : public framework::OpKernel<T> {
       auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
       scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor));
     }
+    if (isinf(scale)) {
+      if (signbit(scale)) {
+        scale = -std::numeric_limits<float>::max();
+      } else {
+        scale = std::numeric_limits<float>::max();
+      }
+    }
+
     if (!bias_after_scale) {
       bias *= scale;
     }
     out->mutable_data<T>(ctx.GetPlace());
-    const auto& runner =
-        NpuOpRunner("Power", {*x}, {*out},
-                    {{"power", power}, {"scale", scale}, {"shift", bias}});
-    runner.Run(stream);
+
+    framework::NPUAttributeMap attrs = {
+        {"power", power}, {"scale", scale}, {"shift", bias}};
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    auto op_func = [](const std::vector<Tensor>& inputs,
+                      const std::vector<Tensor>& outputs,
+                      const NPUAttributeMap& attrs,
+                      const platform::NPUDeviceContext& dev_ctx) {
+      const auto& muls_runner = NpuOpRunner("Muls", {inputs[0]}, {outputs[0]},
+                                            {{"value", attrs.at("scale")}});
+      muls_runner.Run(dev_ctx.stream());
+
+      const auto& adds_runner = NpuOpRunner("Adds", {outputs[0]}, {outputs[0]},
+                                            {{"value", attrs.at("shift")}});
+      adds_runner.Run(dev_ctx.stream());
+    };
+
+    if (x->type() == framework::proto::VarType::INT32) {
+      NpuOpRunner::TypeAdapter({*x}, {*out}, attrs, dev_ctx, op_func,
+                               {framework::proto::VarType::INT32},
+                               {framework::proto::VarType::INT32});
+    } else if (x->type() == framework::proto::VarType::INT64) {
+      NpuOpRunner::TypeAdapter({*x}, {*out}, attrs, dev_ctx, op_func,
+                               {framework::proto::VarType::INT32},
+                               {framework::proto::VarType::INT32});
+    } else {
+      const auto& runner = NpuOpRunner("Power", {*x}, {*out}, attrs);
+      runner.Run(stream);
+    }
   }
 };
@@ -54,4 +86,6 @@ class ScaleNPUKernel : public framework::OpKernel<T> {
 
 REGISTER_OP_NPU_KERNEL(
     scale, paddle::operators::ScaleNPUKernel<float>,
-    paddle::operators::ScaleNPUKernel<paddle::platform::float16>);
+    paddle::operators::ScaleNPUKernel<paddle::platform::float16>,
+    paddle::operators::ScaleNPUKernel<int64_t>,
+    paddle::operators::ScaleNPUKernel<int>);
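For int32 and int64 inputs the kernel above avoids the Power op, presumably because it lacks integer kernels on the NPU, and instead decomposes the computation into a Muls (multiply by scale) followed by an Adds (add shift); the new isinf branch additionally clamps an infinite scale to ±FLT_MAX before it reaches the NPU. The earlier bias *= scale line is what lets the same Muls + Adds pair serve both attribute layouts, since scale * (x + bias) = scale * x + scale * bias. A minimal plain C++ check of that equivalence (illustrative only, not the kernel itself):

// Sketch only: verify the Muls + Adds decomposition used for integer dtypes.
#include <cassert>
#include <cstdint>

int64_t scale_op(int64_t x, float scale, float bias, bool bias_after_scale) {
  if (!bias_after_scale) bias *= scale;  // same pre-multiplication as the kernel
  float out = x * scale;                 // "Muls": multiply by scale
  out += bias;                           // "Adds": add the (possibly pre-scaled) bias
  return static_cast<int64_t>(out);
}

int main() {
  assert(scale_op(3, 2.0f, 1.0f, true) == 7);   // scale * x + bias
  assert(scale_op(3, 2.0f, 1.0f, false) == 8);  // scale * (x + bias)
  return 0;
}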
@@ -201,13 +201,16 @@ class TestExpandV2OpFloat(OpTest):
 # Situation 5: input x is int32
 # skip grad check for int32
 class TestExpandV2OpInteger(OpTest):
+    def init_dtype(self):
+        self.dtype = 'int32'
+
     def setUp(self):
         self.set_npu()
         self.place = paddle.NPUPlace(0)
         self.op_type = "expand_v2"
         self.inputs = {
             'X': np.random.randint(
-                10, size=(2, 4, 20)).astype("int32")
+                10, size=(2, 4, 20)).astype(self.dtype)
         }
         self.attrs = {'shape': [2, 4, 20]}
         output = np.tile(self.inputs['X'], (1, 1, 1))
@@ -221,6 +224,25 @@ class TestExpandV2OpInteger(OpTest):
         self.check_output_with_place(self.place)
 
 
+class TesstExpandV2OpInt64(TestExpandV2OpInteger):
+    def init_dtype(self):
+        self.dtype = 'int64'
+
+
+class TesstExpandV2OpBool(TestExpandV2OpInteger):
+    def init_dtype(self):
+        self.dtype = 'bool'
+
+    def setUp(self):
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "expand_v2"
+        self.inputs = {'X': np.random.randint(10, size=(2, 4, 20)) > 5}
+        self.attrs = {'shape': [2, 4, 20]}
+        output = np.tile(self.inputs['X'], (1, 1, 1))
+        self.outputs = {'Out': output}
+
+
 class TestExpandV2Error(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
...
@@ -120,5 +120,29 @@ class TestFillConstantFP16(OpTest):
         self.check_output_with_place(self.place, atol=1e-3)
 
 
+class TestFillConstantBool(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "fill_constant"
+
+        self.inputs = {}
+        self.attrs = {
+            'shape': [123, 92],
+            'value': True,
+            'dtype': core.VarDesc.VarType.BOOL
+        }
+        self.outputs = {'Out': np.full((123, 92), True).astype(self.dtype)}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.BOOL
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -271,5 +271,30 @@ class TestReduceMaxOpWithOutDtype_fp32_2(TestNPUReduceMaxOp):
         self.dtype = np.float16
 
 
+@skip_check_grad_ci(
+    reason="reduce_max is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMaxOpInt64(TestNPUReduceMaxOp):
+    """Remove Max with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
+        self.attrs = {
+            'dim': [-2, -1],
+            'out_dtype': int(core.VarDesc.VarType.INT64)
+        }
+        self.outputs = {
+            'Out': self.inputs['X'].max(
+                axis=tuple(self.attrs['dim'])).astype(np.float32)
+        }
+
+    def init_dtype(self):
+        self.dtype = np.int64
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -39,7 +39,8 @@ class TestScale(OpTest):
         }
         self.attrs = {'scale': -2.3, 'bias': 0, 'bias_after_scale': True}
         self.outputs = {
-            'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])
+            'Out': (self.inputs['X'] *
+                    self.dtype(self.attrs['scale'])).astype(self.dtype)
         }
 
     def set_npu(self):
@@ -57,6 +58,16 @@ class TestFP16Scale(TestScale):
         self.dtype = np.float16
 
 
+class TestScaleInt(TestScale):
+    def init_dtype(self):
+        self.dtype = np.int32
+
+
+class TestScaleInt64(TestScale):
+    def init_dtype(self):
+        self.dtype = np.int64
+
+
 class TestBiasAfterScale(OpTest):
     def setUp(self):
         self.set_npu()
...