From 64afa638cfa61a2e2716494ea38f54e32c22fc85 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Thu, 25 Aug 2022 19:19:30 +0800
Subject: [PATCH] [NPU] add run_program_op_npu (#45349)

* [NPU] add run_program_op_npu

* add run_program_op_npu ut
---
 paddle/fluid/framework/executor_cache.cc      |   4 +
 paddle/fluid/operators/CMakeLists.txt         |   2 +-
 .../elementwise/elementwise_add_op_npu.cc     |   5 +-
 paddle/fluid/operators/run_program_op_npu.cc  |  29 ++
 .../unittests/npu/test_run_program_op_npu.py  | 311 ++++++++++++++++++
 5 files changed, 349 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/run_program_op_npu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_run_program_op_npu.py

diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc
index 2df86e86a75..de3ca2131aa 100644
--- a/paddle/fluid/framework/executor_cache.cc
+++ b/paddle/fluid/framework/executor_cache.cc
@@ -50,6 +50,10 @@ static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
       execution_strategy.num_threads_ = 1;
       break;
     }
+    case platform::DeviceType::NPU: {
+      execution_strategy.num_threads_ = 1;
+      break;
+    }
     default:
       PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
                                                  device_type));
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 33520eaa13e..a722aca6a9f 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -106,7 +106,7 @@ set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel bac
 register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
 recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

-op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
+op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
 target_link_libraries(run_program_op cuda_graph_with_memory_pool)
 op_library(quantize_linear_op DEPS phi)
 op_library(save_combine_op DEPS string_array)
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
index 424b38b27b4..726b4186030 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
@@ -40,7 +40,10 @@ class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
     auto x_dims = x->dims();
     auto y_dims = y->dims();
     axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    if (x_dims.size() >= y_dims.size()) {
+
+    if (x_dims.size() == y_dims.size()) {
+      direct_compute = true;
+    } else if (x_dims.size() > y_dims.size()) {
       direct_compute = x_dims.size() == (y_dims.size() + axis);
     } else {
       direct_compute = y_dims.size() == (x_dims.size() + axis);
diff --git a/paddle/fluid/operators/run_program_op_npu.cc b/paddle/fluid/operators/run_program_op_npu.cc
new file mode 100644
index 00000000000..679d2fbdf5a
--- /dev/null
+++ b/paddle/fluid/operators/run_program_op_npu.cc
@@ -0,0 +1,29 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_ASCEND_CL
+#include "paddle/fluid/operators/run_program_op.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+/* see [Why use single type kernel] */
+REGISTER_OP_NPU_KERNEL(
+    run_program,
+    ops::RunProgramOpKernel<paddle::platform::NPUDeviceContext, float>);
+REGISTER_OP_NPU_KERNEL(
+    run_program_grad,
+    ops::RunProgramGradOpKernel<paddle::platform::NPUDeviceContext, float>);
+#endif
diff --git a/python/paddle/fluid/tests/unittests/npu/test_run_program_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_run_program_op_npu.py
new file mode 100644
index 00000000000..a0bb1ca8af3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_run_program_op_npu.py
@@ -0,0 +1,311 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import contextlib
+import unittest
+import numpy as np
+import six
+import sys
+
+sys.path.append("..")
+
+from op_test import OpTest
+import paddle
+from paddle import _C_ops
+import paddle.fluid as fluid
+from paddle import compat as cpt
+from paddle.fluid import core, framework, executor
+from paddle.fluid.layers.utils import _hash_with_id
+from paddle.fluid.framework import _in_eager_mode_
+
+paddle.enable_static()
+np.random.seed(1243)
+
+
+@contextlib.contextmanager
+def program_scope_guard():
+    prog = fluid.Program()
+    startup_prog = fluid.Program()
+    scope = fluid.core.Scope()
+    with fluid.scope_guard(scope):
+        with fluid.program_guard(prog, startup_prog):
+            with fluid.unique_name.guard():
+                yield
+
+
+# NOTE: Because RunProgramOp has a special output of type std::vector<Scope*>,
+# the OpTest cannot be used in RunProgramOp. The variable type cannot be specified
+# when creating output variables in OpTest, the default type is LoDTensor
+# NOTE: the gradient test method in OpTest also cannot be used for RunProgramOp,
+# because it holds a BlockDesc type attr, which OperatorFactory can't parse
+# when creating the Operator, so here we compare gradients with the static graph
+# NOTE: Here we rewrite a simple unittest framework for RunProgramOp
+class RunProgramNPUOpTest(unittest.TestCase):
+
+    def build_model(self):
+        raise NotImplementedError(
+            "RunProgramOp test should implement build_model")
+
+    def check_output(self):
+        places = [fluid.NPUPlace(0)]
+        for place in places:
+            # TODO: RunProgramOp is not recommended for use in static mode now
+            self.expect_outs = self.run_static_model(place, is_test=True)
+            self.check_output_with_place(place)
+
+    def check_grad(self):
+        places = [fluid.NPUPlace(0)]
+        for place in places:
+            # TODO: RunProgramOp is not recommended for use in static mode now
+            self.expect_grads = self.run_static_model(place, is_test=False)
+            self.check_grad_with_place(place)
+
+    def run_static_model(self, place, is_test=True):
+        with program_scope_guard():
+            startup_program = fluid.default_startup_program()
+            main_program = fluid.default_main_program()
+
+            self.build_model()
+
+            exe = fluid.Executor(place)
+            exe.run(startup_program)
+
+            if is_test:
+                fetch_list = self.output_names['Out']
+            else:
+                fetch_list = self.get_param_grad_names()
+
+            outs = exe.run(main_program,
+                           feed=self.inputs['X'],
+                           fetch_list=fetch_list)
+            return outs
+
+    def get_program_desc(self):
+        with program_scope_guard():
+            fwd_op_num = self.build_model()
+            return fluid.default_main_program().desc, fwd_op_num
+
+    def prepare_attrs(self):
+        return ('global_block', self.program_desc.block(0), 'start_op_index', 0,
+                'end_op_index', self.fwd_op_num, 'program_id',
+                _hash_with_id(self.program_desc, self))
+
+    def get_param_grad_names(self):
+        grad_names = []
+        for var_name in self.inputs['Params']:
+            grad_names.append(var_name + core.grad_var_suffix())
+        return grad_names
+
+    def check_output_with_place(self, place):
+        # Step 1. run op
+        actual_outs = self.calc_dygraph_output(place)
+
+        # Step 2. compare output
+        for expect_v, actual_v in six.moves.zip(self.expect_outs, actual_outs):
+            np.testing.assert_allclose(expect_v,
+                                       actual_v.numpy(),
+                                       rtol=1e-05,
+                                       atol=1e-05)
+
+    def check_grad_with_place(self, place):
+        # Step 1. calc grads
+        actual_grads = self.calc_dygraph_grad(place)
+
+        # Step 2. compare grads
+        for expect_v, actual_v in six.moves.zip(self.expect_grads,
+                                                actual_grads):
+            np.testing.assert_array_almost_equal(expect_v, actual_v)
+            np.testing.assert_allclose(expect_v,
+                                       actual_v,
+                                       rtol=1e-05,
+                                       atol=1e-05)
+
+    def prepare_dygraph_input(self, place, return_param_list=False):
+
+        def create_var_base(is_input, name, np_value, stop_gradient):
+            if _in_eager_mode_:
+                var = core.eager.Tensor(value=np_value,
+                                        name=name,
+                                        place=place,
+                                        zero_copy=True)
+            else:
+                var = core.VarBase(value=np_value,
+                                   name=name,
+                                   place=place,
+                                   zero_copy=True)
+            var.stop_gradient = stop_gradient
+            return var
+
+        # build inputs
+        inputs = {}
+        param_list = []
+        inputs['X'] = []
+        for name, np_value in self.inputs['X'].items():
+            var = create_var_base(True, name, np_value, True)
+            inputs['X'].append(var)
+        inputs['Params'] = []
+        for name, np_value in self.inputs['Params'].items():
+            var = create_var_base(True, name, np_value, False)
+            inputs['Params'].append(var)
+            if return_param_list:
+                param_list.append(var)
+
+        if return_param_list:
+            return inputs, param_list
+        return inputs
+
+    def prepare_dygraph_output(self):
+
+        def create_var_base(is_input, name):
+            var = framework._varbase_creator(dtype=None, shape=None, name=name)
+            var.stop_gradient = False
+            return var
+
+        # build outputs
+        outputs = {}
+        outputs['Out'] = []
+        for name in self.output_names['Out']:
+            outputs['Out'].append(create_var_base(False, name))
+
+        if _in_eager_mode_:
+            outputs['OutScope'] = [core.Scope()]
+        else:
+            outputs['OutScope'] = framework._varbase_creator(
+                type=core.VarDesc.VarType.STEP_SCOPES,
+                name="program_out_scope",
+                persistable=True)
+            inner_scope = core.Scope()
+            outputs['OutScope'].value().set_scope(inner_scope)
+
+        outputs['DOut'] = [create_var_base(False, "Fake_var")]
+        return outputs
+
+    def calc_dygraph_output(self, place):
+        self.program_desc, self.fwd_op_num = self.get_program_desc()
+        self.attrs = self.prepare_attrs()
+
+        with fluid.dygraph.guard(place):
+            inputs = self.prepare_dygraph_input(place)
+            outputs = self.prepare_dygraph_output()
+
+            _C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
+                               outputs['OutScope'], outputs['DOut'], None,
+                               *self.attrs)
+            return outputs['Out']
+
+    def calc_dygraph_grad(self, place):
+        self.program_desc, self.fwd_op_num = self.get_program_desc()
+        self.attrs = self.prepare_attrs()
+
+        with fluid.dygraph.guard(place):
+            # Step 1. run forward
+            inputs, input_param_list = self.prepare_dygraph_input(place, True)
+            outputs = self.prepare_dygraph_output()
+
+            _C_ops.run_program(inputs['X'], inputs['Params'], outputs['Out'],
+                               outputs['OutScope'], outputs['DOut'], None,
+                               *self.attrs)
+
+            for param in input_param_list:
+                var_type = self._get_grad_vartype(param.name)
+                if var_type is None:
+                    continue
+                param._set_grad_type(var_type)
+
+            # Step 2. run backward
+            # NOTE: in unittest, only support single output now
+            actual_outs = outputs['Out']
+            assert len(actual_outs) == 1
+            actual_outs[0].backward()
+
+            # Step 3. prepare grads
+            grads = []
+            for param in input_param_list:
+                grad = param.gradient()
+                grads.append(grad)
+            return grads
+
+    def _get_grad_vartype(self, name):
+        assert self.program_desc is not None
+        grad_name = name + core.grad_var_suffix()
+        for i in six.moves.range(self.program_desc.num_blocks()):
+            block = self.program_desc.block(i)
+            var_desc = block.find_var_recursive(cpt.to_bytes(grad_name))
+            return var_desc.type() if var_desc is not None else None
+
+
+class TestRunProgramOpWithFC(RunProgramNPUOpTest):
+
+    def setUp(self):
+        self.op_type = "run_program"
+        self.dtype = np.float32
+        self.input_names = {
+            'X': ['img'],
+            'Params': ['weight_param', 'bias_param']
+        }
+        self.output_names = {'Out': ['fc_0.tmp_2']}
+
+        self.inputs = {
+            'X': {
+                self.input_names['X'][0]:
+                np.random.random((32, 1, 28, 28)).astype(self.dtype)
+            },
+            'Params': {
+                self.input_names['Params'][0]:
+                np.random.random((784, 10)).astype(self.dtype),
+                self.input_names['Params'][1]:
+                np.random.random((32, 10)).astype(self.dtype)
+            }
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad()
+
+    def build_model(self):
+        # 1. simple model
+        img = fluid.data(name=self.input_names['X'][0],
+                         shape=[None, 1, 28, 28],
+                         dtype='float32')
+        weight_attr = fluid.ParamAttr(
+            name=self.input_names['Params'][0],
+            learning_rate=0.5,
+            initializer=fluid.initializer.NumpyArrayInitializer(
+                self.inputs['Params'][self.input_names['Params'][0]]),
+            trainable=True)
+        bias_attr = fluid.ParamAttr(
+            name=self.input_names['Params'][1],
+            learning_rate=0.5,
+            initializer=fluid.initializer.NumpyArrayInitializer(
+                self.inputs['Params'][self.input_names['Params'][1]]),
+            trainable=True)
+        pred = fluid.layers.fc(input=img,
+                               size=10,
+                               param_attr=weight_attr,
+                               bias_attr=bias_attr,
+                               act='relu')
+        # 2. get forward op num
+        fwd_op_num = fluid.default_main_program().global_block().desc.op_size()
+        # 3. append backward
+        grads = fluid.backward.gradients(targets=[pred], inputs=[img])
+
+        return fwd_op_num
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
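
For context, RunProgramOp is the operator that the dygraph-to-static (@paddle.jit.to_static) path executes under the hood; the unit test above drives it directly through _C_ops.run_program. A minimal end-to-end sketch of the public-API path, assuming a PaddlePaddle build with PADDLE_WITH_ASCEND_CL and a device available as fluid.NPUPlace(0); the SimpleNet class below is illustrative and not part of the patch:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.disable_static()  # ensure dygraph mode (the unit test above enables static mode)


class SimpleNet(paddle.nn.Layer):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = paddle.nn.Linear(784, 10)

    @paddle.jit.to_static
    def forward(self, x):
        return paddle.nn.functional.relu(self.fc(x))


with fluid.dygraph.guard(fluid.NPUPlace(0)):
    net = SimpleNet()
    x = paddle.to_tensor(np.random.random((32, 784)).astype('float32'))
    loss = net(x).mean()  # forward of the traced program runs through run_program's NPU kernel
    loss.backward()       # backward runs through run_program_grad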