diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc index 20c56d6a279334c87245857fdcceb375c02aed7b..e0cb4dee5311afc296c277818ec2769b7782e1ac 100644 --- a/paddle/fluid/operators/activation_op_npu.cc +++ b/paddle/fluid/operators/activation_op_npu.cc @@ -503,7 +503,6 @@ class SwishGradNPUKernel : public framework::OpKernel { beta_x.mutable_data(x->dims(), ctx.GetPlace()); sigmoid_out.mutable_data(x->dims(), ctx.GetPlace()); swish_out.mutable_data(x->dims(), ctx.GetPlace()); - const auto& muls_runner = NpuOpRunner("Muls", {*x}, {beta_x}, {{"value", beta}}); muls_runner.Run(stream); @@ -515,6 +514,9 @@ class SwishGradNPUKernel : public framework::OpKernel { const auto& mul_runner = NpuOpRunner("Mul", {sigmoid_out, *x}, {swish_out}, {}); mul_runner.Run(stream); + const auto& muls_runner2 = + NpuOpRunner("Muls", {swish_out}, {swish_out}, {{"value", beta}}); + muls_runner2.Run(stream); const auto& mul_runner1 = NpuOpRunner("Mul", {sigmoid_out, swish_out}, {*dx}, {}); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc index b2030ad21e8d1fb93eed2582b12aea4c0844c319..36a7d54f8c1c2eec3fcc5c4c1d5bb947079bb0b6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op_npu.cc @@ -143,8 +143,16 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL(elementwise_mul, ops::ElementwiseMulNPUKernel, - ops::ElementwiseMulNPUKernel); + ops::ElementwiseMulNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + ops::ElementwiseMulNPUKernel, +#endif + ops::ElementwiseMulNPUKernel); REGISTER_OP_NPU_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradNPUKernel, - ops::ElementwiseMulGradNPUKernel); + ops::ElementwiseMulGradNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + ops::ElementwiseMulGradNPUKernel, +#endif + ops::ElementwiseMulGradNPUKernel); diff --git a/paddle/fluid/operators/expand_v2_op_npu.cc b/paddle/fluid/operators/expand_v2_op_npu.cc index 4b0e0770573a6f7091f1a0db7534e923eeb61d99..46385a20ab989278a429c6c044596d73bae6d8ad 100644 --- a/paddle/fluid/operators/expand_v2_op_npu.cc +++ b/paddle/fluid/operators/expand_v2_op_npu.cc @@ -106,11 +106,28 @@ class ExpandV2NPUKernel : public framework::OpKernel { Out->Resize(out_dims); Out->mutable_data(ctx.GetPlace()); - const auto& runner = NpuOpRunner("ExpandD", {*X}, {*Out}, attr_input); - auto stream = - ctx.template device_context() - .stream(); - runner.Run(stream); + const auto& dev_ctx = + ctx.template device_context(); + auto op_func = [](const std::vector& inputs, + const std::vector& outputs, + const NPUAttributeMap& attrs, + const platform::NPUDeviceContext& dev_ctx) { + const auto& runner = NpuOpRunner("ExpandD", inputs, outputs, attrs); + runner.Run(dev_ctx.stream()); + }; + + if (X->type() == framework::proto::VarType::BOOL) { + NpuOpRunner::TypeAdapter({*X}, {*Out}, attr_input, dev_ctx, op_func, + {framework::proto::VarType::UINT8}, + {framework::proto::VarType::UINT8}); + } else if (X->type() == framework::proto::VarType::INT64) { + NpuOpRunner::TypeAdapter({*X}, {*Out}, attr_input, dev_ctx, op_func, + {framework::proto::VarType::INT32}, + {framework::proto::VarType::INT32}); + } else { + const auto& runner = NpuOpRunner("ExpandD", {*X}, {*Out}, attr_input); + runner.Run(dev_ctx.stream()); + } } }; @@ -181,7 +198,9 @@ REGISTER_OP_NPU_KERNEL( ops::ExpandV2NPUKernel, ops::ExpandV2NPUKernel, - ops::ExpandV2NPUKernel); + ops::ExpandV2NPUKernel, + ops::ExpandV2NPUKernel, + ops::ExpandV2NPUKernel); REGISTER_OP_NPU_KERNEL( expand_v2_grad, diff --git a/paddle/fluid/operators/fill_constant_op_npu.cc b/paddle/fluid/operators/fill_constant_op_npu.cc index 16a2433f5cad6f26a64f043542253d09544ef17d..7241fcaf1878ff69ec22e4aab46f0f5936a254fb 100644 --- a/paddle/fluid/operators/fill_constant_op_npu.cc +++ b/paddle/fluid/operators/fill_constant_op_npu.cc @@ -22,13 +22,13 @@ namespace operators { template class FillConstantNPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { + void Compute(const framework::ExecutionContext &ctx) const override { auto data_type = static_cast(ctx.Attr("dtype")); auto str_value = ctx.Attr("str_value"); auto float_value = ctx.Attr("value"); - auto* out_var = ctx.Output("Out"); + auto *out_var = ctx.Output("Out"); auto stream = ctx.template device_context() .stream(); @@ -59,28 +59,49 @@ class FillConstantNPUKernel : public framework::OpKernel { } auto shape = GetShape(ctx); - Tensor tensor_value(data_type); - tensor_value.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&tensor_value, value); - out_var->mutable_data(shape, ctx.GetPlace()); - - NpuOpRunner runner; + if (data_type != framework::proto::VarType::BOOL) { + Tensor tensor_value(data_type); + tensor_value.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&tensor_value, value); + NpuOpRunner runner; #if (CANN_VERSION_CODE >= 503003) - runner.SetType("FillD") - .AddInput(tensor_value) - .AddOutput(*out_var) - .AddAttrs( - {{ "dims", - framework::vectorize(shape) }}) - .Run(stream); + runner.SetType("FillD") + .AddInput(tensor_value) + .AddOutput(*out_var) + .AddAttrs( + {{ "dims", + framework::vectorize(shape) }}) + .Run(stream); #else - runner.SetType("Fill") - .AddInput(framework::vectorize(shape)) - .AddInput(tensor_value) - .AddOutput(*out_var) - .Run(stream); + runner.SetType("Fill") + .AddInput(framework::vectorize(shape)) + .AddInput(tensor_value) + .AddOutput(*out_var) + .Run(stream); #endif + } else { + const auto &dev_ctx = + ctx.template device_context(); + auto op_func = [&shape, &value]( + const std::vector &inputs, const std::vector &outputs, + const NPUAttributeMap &attrs, + const platform::NPUDeviceContext &dev_ctx) { + Tensor tensor_value; + tensor_value.mutable_data({1}, dev_ctx.GetPlace()); + FillNpuTensorWithConstant(&tensor_value, + static_cast(value)); + + NpuOpRunner runner; + runner.SetType("Fill") + .AddInput(framework::vectorize(shape)) + .AddInput(tensor_value) + .AddOutput(outputs[0]) + .Run(dev_ctx.stream()); + }; + NpuOpRunner::TypeAdapter({}, {*out_var}, {}, dev_ctx, op_func, {}, + {framework::proto::VarType::UINT8}); + } } }; } // namespace operators diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc index 830e18cb8a14c09a9f999bc11d3bbda08e31f1fc..e104fc157d6f0549339183a7b96d1eb7c657c85b 100644 --- a/paddle/fluid/operators/npu_op_runner.cc +++ b/paddle/fluid/operators/npu_op_runner.cc @@ -436,5 +436,67 @@ void NpuOpRunner::Run(aclrtStream stream) const { PADDLE_ENFORCE_NPU_SUCCESS(ret); } +void NpuOpRunner::TypeAdapter( + const std::vector &inputs, const std::vector &outputs, + const NPUAttributeMap &attrs, const platform::NPUDeviceContext &dev_ctx, + std::function &, const std::vector &, + const NPUAttributeMap &, + const platform::NPUDeviceContext &)> + op_runner, + const std::vector &input_type, + const std::vector &output_type) { + PADDLE_ENFORCE_EQ( + inputs.size(), input_type.size(), + platform::errors::InvalidArgument( + "The number of inputs must be equal to input_type.size().")); + PADDLE_ENFORCE_EQ( + outputs.size(), output_type.size(), + platform::errors::InvalidArgument( + "The number of outputs must be equal to output_type.size().")); + + std::vector tmp_inputs(inputs.size()); + std::vector tmp_outputs(outputs.size()); + + for (size_t i = 0; i < input_type.size(); ++i) { + bool cast_input = + (input_type[i] == -1 || input_type[i] != inputs[i].type()); + if (!cast_input) { + tmp_inputs[i].ShareDataWith(inputs[i]); + } else { + tmp_inputs[i].Resize(inputs[i].dims()); + tmp_inputs[i].mutable_data(dev_ctx.GetPlace(), input_type[i]); + + const auto &cast_runner = NpuOpRunner( + "Cast", {inputs[i]}, {tmp_inputs[i]}, + {{"dst_type", static_cast(ConvertToNpuDtype(input_type[i]))}}); + cast_runner.Run(dev_ctx.stream()); + } + } + for (size_t i = 0; i < output_type.size(); ++i) { + bool cast_output = + (output_type[i] == -1 || output_type[i] != outputs[i].type()); + if (!cast_output) { + tmp_outputs[i].ShareDataWith(outputs[i]); + } else { + tmp_outputs[i].Resize(outputs[i].dims()); + tmp_outputs[i].mutable_data(dev_ctx.GetPlace(), output_type[i]); + } + } + + op_runner(tmp_inputs, tmp_outputs, attrs, dev_ctx); + + for (size_t i = 0; i < output_type.size(); ++i) { + bool cast_output = + (output_type[i] == -1 || output_type[i] != outputs[i].type()); + if (cast_output) { + const auto &cast_runner = NpuOpRunner( + "Cast", {tmp_outputs[i]}, {outputs[i]}, + {{"dst_type", + static_cast(ConvertToNpuDtype(outputs[i].type()))}}); + cast_runner.Run(dev_ctx.stream()); + } + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/npu_op_runner.h b/paddle/fluid/operators/npu_op_runner.h index 6db5f17d67118166b5d8a8a461c98ca83b79b782..a4a3786b5da53ae36130584559f6b45e07918996 100644 --- a/paddle/fluid/operators/npu_op_runner.h +++ b/paddle/fluid/operators/npu_op_runner.h @@ -103,6 +103,16 @@ class NpuOpRunner { void Run(aclrtStream stream = nullptr) const; + static void TypeAdapter( + const std::vector &inputs, const std::vector &outputs, + const NPUAttributeMap &attrs, const platform::NPUDeviceContext &dev_ctx, + std::function &, + const std::vector &, const NPUAttributeMap &, + const platform::NPUDeviceContext &)> + op_runner, + const std::vector &input_type, + const std::vector &output_type); + private: aclTensorDesc *CreateTensorDesc(Tensor tensor, aclMemType mem_type = ACL_MEMTYPE_DEVICE); diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc index 5efc7e9b869b7d921b0b510c8da843b510102d1a..68417cdad50c006c5e568ccc6475ed65dbec9624 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc @@ -73,20 +73,33 @@ class ReduceMaxNPUKernel : public framework::OpKernel { attr_input = {{"axes", dim_vec}, {"keep_dims", keep_dim}}; } - auto stream = - ctx.template device_context() - .stream(); - - const auto& runner = - NpuOpRunner("ReduceMaxD", {*x}, {cast_out}, attr_input); - runner.Run(stream); + const auto& dev_ctx = + ctx.template device_context(); + if (x->type() == framework::proto::VarType::INT64) { + auto op_func = [](const std::vector& inputs, + const std::vector& outputs, + const NPUAttributeMap& attrs, + const platform::NPUDeviceContext& dev_ctx) { + const auto& runner = + NpuOpRunner("ReduceMaxD", {inputs[0]}, {outputs[0]}, attrs); + runner.Run(dev_ctx.stream()); + }; + + NpuOpRunner::TypeAdapter({*x}, {cast_out}, attr_input, dev_ctx, op_func, + {framework::proto::VarType::INT32}, + {framework::proto::VarType::INT32}); + } else { + const auto& runner = + NpuOpRunner("ReduceMaxD", {*x}, {cast_out}, attr_input); + runner.Run(dev_ctx.stream()); + } if (x->type() != cast_out_dtype) { auto dst_dtype = ConvertToNpuDtype(cast_out_dtype); const auto& runner_cast = NpuOpRunner("Cast", {cast_out}, {*out}, {{"dst_type", static_cast(dst_dtype)}}); - runner_cast.Run(stream); + runner_cast.Run(dev_ctx.stream()); } } }; @@ -98,4 +111,6 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_NPU_KERNEL( reduce_max, ops::ReduceMaxNPUKernel, - ops::ReduceMaxNPUKernel); + ops::ReduceMaxNPUKernel, + ops::ReduceMaxNPUKernel, + ops::ReduceMaxNPUKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op_npu.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op_npu.cc index 78bd42ff00c83f409d1ec3d094ab8a03a2a68eb2..33fcdbce9d0eeb42af62f287f66d97911e077bf9 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op_npu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op_npu.cc @@ -142,12 +142,18 @@ namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL( reduce_sum, ops::ReduceSumNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + ops::ReduceSumNPUKernel, +#endif ops::ReduceSumNPUKernel, ops::ReduceSumNPUKernel); REGISTER_OP_NPU_KERNEL( reduce_sum_grad, ops::ReduceSumGradNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + ops::ReduceSumGradNPUKernel, +#endif ops::ReduceSumGradNPUKernel, ops::ReduceSumGradNPUKernel); diff --git a/paddle/fluid/operators/scale_op_npu.cc b/paddle/fluid/operators/scale_op_npu.cc index 744a9b137f622e263e4b369a1e195d65ccf8cacb..c2f320ed684b889456249b9d039a46fba7a86af1 100644 --- a/paddle/fluid/operators/scale_op_npu.cc +++ b/paddle/fluid/operators/scale_op_npu.cc @@ -37,15 +37,47 @@ class ScaleNPUKernel : public framework::OpKernel { auto* scale_tensor = ctx.Input("ScaleTensor"); scale = static_cast(GetAttrFromTensor(scale_tensor)); } - + if (isinf(scale)) { + if (signbit(scale)) { + scale = -std::numeric_limits::max(); + } else { + scale = std::numeric_limits::max(); + } + } if (!bias_after_scale) { bias *= scale; } out->mutable_data(ctx.GetPlace()); - const auto& runner = - NpuOpRunner("Power", {*x}, {*out}, - {{"power", power}, {"scale", scale}, {"shift", bias}}); - runner.Run(stream); + + framework::NPUAttributeMap attrs = { + {"power", power}, {"scale", scale}, {"shift", bias}}; + const auto& dev_ctx = + ctx.template device_context(); + auto op_func = [](const std::vector& inputs, + const std::vector& outputs, + const NPUAttributeMap& attrs, + const platform::NPUDeviceContext& dev_ctx) { + const auto& muls_runner = NpuOpRunner("Muls", {inputs[0]}, {outputs[0]}, + {{"value", attrs.at("scale")}}); + muls_runner.Run(dev_ctx.stream()); + + const auto& adds_runner = NpuOpRunner("Adds", {outputs[0]}, {outputs[0]}, + {{"value", attrs.at("shift")}}); + adds_runner.Run(dev_ctx.stream()); + }; + + if (x->type() == framework::proto::VarType::INT32) { + NpuOpRunner::TypeAdapter({*x}, {*out}, attrs, dev_ctx, op_func, + {framework::proto::VarType::INT32}, + {framework::proto::VarType::INT32}); + } else if (x->type() == framework::proto::VarType::INT64) { + NpuOpRunner::TypeAdapter({*x}, {*out}, attrs, dev_ctx, op_func, + {framework::proto::VarType::INT32}, + {framework::proto::VarType::INT32}); + } else { + const auto& runner = NpuOpRunner("Power", {*x}, {*out}, attrs); + runner.Run(stream); + } } }; @@ -54,4 +86,6 @@ class ScaleNPUKernel : public framework::OpKernel { REGISTER_OP_NPU_KERNEL( scale, paddle::operators::ScaleNPUKernel, - paddle::operators::ScaleNPUKernel); + paddle::operators::ScaleNPUKernel, + paddle::operators::ScaleNPUKernel, + paddle::operators::ScaleNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py index d48d2a8430134a623e587028316be4133d5b24ca..fd0b9850308b2671b5efa003c28df79338b56877 100755 --- a/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py @@ -201,13 +201,16 @@ class TestExpandV2OpFloat(OpTest): # Situation 5: input x is int32 # skip grad check for int32 class TestExpandV2OpInteger(OpTest): + def init_dtype(self): + self.dtype = 'int32' + def setUp(self): self.set_npu() self.place = paddle.NPUPlace(0) self.op_type = "expand_v2" self.inputs = { 'X': np.random.randint( - 10, size=(2, 4, 20)).astype("int32") + 10, size=(2, 4, 20)).astype(self.dtype) } self.attrs = {'shape': [2, 4, 20]} output = np.tile(self.inputs['X'], (1, 1, 1)) @@ -221,6 +224,25 @@ class TestExpandV2OpInteger(OpTest): self.check_output_with_place(self.place) +class TesstExpandV2OpInt64(TestExpandV2OpInteger): + def init_dtype(self): + self.dtype = 'int64' + + +class TesstExpandV2OpBool(TestExpandV2OpInteger): + def init_dtype(self): + self.dtype = 'bool' + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "expand_v2" + self.inputs = {'X': np.random.randint(10, size=(2, 4, 20)) > 5} + self.attrs = {'shape': [2, 4, 20]} + output = np.tile(self.inputs['X'], (1, 1, 1)) + self.outputs = {'Out': output} + + class TestExpandV2Error(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py index 2ab15213803a9013d44eb3e1e98a0b286627c6a6..a3e781c990ecb151b1b7a5c80bc256324b6fb43e 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py @@ -120,5 +120,29 @@ class TestFillConstantFP16(OpTest): self.check_output_with_place(self.place, atol=1e-3) +class TestFillConstantBool(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = { + 'shape': [123, 92], + 'value': True, + 'dtype': core.VarDesc.VarType.BOOL + } + self.outputs = {'Out': np.full((123, 92), True).astype(self.dtype)} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.BOOL + + def test_check_output(self): + self.check_output_with_place(self.place) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_reduce_max_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_reduce_max_op_npu.py index f6c346159b8bee25e242c128412a0a36c78f4f1f..68a28ea72e1fc091b04914de19533534962b0885 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_reduce_max_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_reduce_max_op_npu.py @@ -271,5 +271,30 @@ class TestReduceMaxOpWithOutDtype_fp32_2(TestNPUReduceMaxOp): self.dtype = np.float16 +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMaxOpInt64(TestNPUReduceMaxOp): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.set_npu() + self.init_dtype() + + self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)} + self.attrs = { + 'dim': [-2, -1], + 'out_dtype': int(core.VarDesc.VarType.INT64) + } + self.outputs = { + 'Out': self.inputs['X'].max( + axis=tuple(self.attrs['dim'])).astype(np.float32) + } + + def init_dtype(self): + self.dtype = np.int64 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_scale_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_scale_op_npu.py index 65ec28fbf7d3a394ce0ac93c81a651445d584c75..424c4ca0ff35d304912e1989e04bc573512a3e65 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_scale_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_scale_op_npu.py @@ -39,7 +39,8 @@ class TestScale(OpTest): } self.attrs = {'scale': -2.3, 'bias': 0, 'bias_after_scale': True} self.outputs = { - 'Out': self.inputs['X'] * self.dtype(self.attrs['scale']) + 'Out': (self.inputs['X'] * + self.dtype(self.attrs['scale'])).astype(self.dtype) } def set_npu(self): @@ -57,6 +58,16 @@ class TestFP16Scale(TestScale): self.dtype = np.float16 +class TestScaleInt(TestScale): + def init_dtype(self): + self.dtype = np.int32 + + +class TestScaleInt64(TestScale): + def init_dtype(self): + self.dtype = np.int64 + + class TestBiasAfterScale(OpTest): def setUp(self): self.set_npu()