diff --git a/cmake/operators.cmake b/cmake/operators.cmake index e58dbf77b4c9c83bd131404d61181adf89fec305..8469dc4c02ee37b333254d6d35b0eb48354d4b86 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -335,6 +335,17 @@ function(op_library TARGET) endif() endforeach() + # pybind USE_OP_DEVICE_KERNEL for ROCm + list (APPEND hip_srcs ${hip_cc_srcs}) + # message("hip_srcs ${hip_srcs}") + foreach(hip_src ${hip_srcs}) + set(op_name "") + find_register(${hip_src} "REGISTER_OP_CUDA_KERNEL" op_name) + if(NOT ${op_name} EQUAL "") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") + set(pybind_flag 1) + endif() + endforeach() # pybind USE_OP_DEVICE_KERNEL for CUDNN/MIOPEN list(APPEND cudnn_cu_srcs ${cudnn_cu_cc_srcs}) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 3c629b2145e141b4251d5b8316d980d8eea3985e..7e7114111c4e1d1ca6f7a4cbafa183284248b854 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -17,6 +17,12 @@ import re import argparse import os +# For API dispatch used at python-level +# { op_name : [arg_name, ...] } +core_ops_returns_info = {} +core_ops_args_info = {} +core_ops_args_type_info = {} + def ParseArguments(): parser = argparse.ArgumentParser( @@ -130,17 +136,16 @@ def ParseYamlArgs(string): attrs_list = [] args = [x.strip() for x in string.strip().split(",")] - atype = r'((const )?\S+) ' - aname = r'(\S+)' + aname = r'(.*)' pattern = f'{atype}{aname}' for i in range(len(args)): arg = args[i] m = re.search(pattern, arg) - arg_type = m.group(1) - arg_name = m.group(3).split("=")[0] - default_value = m.group(3).split("=")[1] if len(m.group(3).split( - "=")) > 1 else None + arg_type = m.group(1).strip() + arg_name = m.group(3).split("=")[0].strip() + default_value = m.group(3).split("=")[1].strip() if len( + m.group(3).split("=")) > 1 else None if "Tensor" in arg_type: assert default_value is None inputs_list.append([arg_name, arg_type, i]) @@ -262,7 +267,6 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list, forward_attr_type = forward_attrs_list[i][1] forward_attr_default = forward_attrs_list[i][2] forward_attr_pos = forward_attrs_list[i][3] - assert orig_attr_type == forward_attr_type assert orig_attr_default == forward_attr_default assert orig_attr_pos == forward_attr_pos @@ -741,26 +745,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, # Get Function Args num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys( )) - inputs_args_list = ["" for i in range(num_inputs)] + inputs_args_definition_list = ["" for i in range(num_inputs)] + inputs_args_declaration_list = ["" for i in range(num_inputs)] inputs_call_list = ["" for i in range(num_inputs)] for name, (ttype, pos) in forward_inputs_position_map.items(): inputs_call_list[pos] = f"{name}" if IsPlainTensorType(ttype): - inputs_args_list[ + inputs_args_definition_list[ + pos] = f"const paddle::experimental::Tensor& {name}" + inputs_args_declaration_list[ pos] = f"const paddle::experimental::Tensor& {name}" else: assert IsVectorTensorType(ttype) - inputs_args_list[ + inputs_args_definition_list[ + pos] = f"const std::vector& {name}" + inputs_args_declaration_list[ pos] = f"const std::vector& {name}" for name, atype, default_val, pos in forward_attrs_list: inputs_call_list[pos] = name if default_val is not None: - inputs_args_list[pos] = f"{atype} {name} = {default_val}" + inputs_args_declaration_list[ + pos] = f"{atype} {name} = {default_val}" else: - inputs_args_list[pos] = f"{atype} {name}" + inputs_args_declaration_list[pos] = f"{atype} {name}" + inputs_args_definition_list[pos] = f"{atype} {name}" - inputs_args_str = ", ".join(inputs_args_list) + inputs_args_declaration_str = ", ".join(inputs_args_declaration_list) + inputs_args_definition_str = ", ".join(inputs_args_definition_list) inputs_call_args_str = ", ".join(inputs_call_list) # Forward Full Logic @@ -812,13 +824,95 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, forward_function_name = GetForwardFunctionName(fwd_api_name) forward_function_str = FORWARD_FUNCTION_TEMPLATE.format( - returns_type_str, forward_function_name, inputs_args_str, + returns_type_str, forward_function_name, inputs_args_definition_str, forward_call_str, node_creation_str, returns_str) - forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});" + forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});" return forward_function_str, forward_function_declaration_str +def CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, + forward_outputs_position_map, forward_attrs_list): + # fwd_api_name : "" + # forward_inputs_position_map = { "name" : [type, fwd_position] } + # forward_outputs_position_map = { "name" : [type, fwd_position] } + # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] + num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list) + num_returns = len(forward_outputs_position_map.keys()) + + final_state_fwd_api_name = "final_state_" + fwd_api_name + core_ops_returns_info[ + final_state_fwd_api_name] = ["" for i in range(num_returns)] + core_ops_args_info[final_state_fwd_api_name] = ["" for i in range(num_args)] + core_ops_args_type_info[ + final_state_fwd_api_name] = ["" for i in range(num_args)] + for name, (ttype, pos) in forward_inputs_position_map.items(): + core_ops_args_info[final_state_fwd_api_name][pos] = name + if IsPlainTensorType(ttype): + core_ops_args_type_info[final_state_fwd_api_name][pos] = "tensor" + else: + assert IsVectorTensorType(ttype) + core_ops_args_type_info[final_state_fwd_api_name][pos] = "list" + + for name, _, _, pos in forward_attrs_list: + core_ops_args_info[final_state_fwd_api_name][pos] = name + + for name, (ttype, pos) in forward_outputs_position_map.items(): + core_ops_returns_info[final_state_fwd_api_name][pos] = name + + +def GenerateCoreOpInfoDeclaration(): + core_ops_declaration_str = """ + extern std::unordered_map> core_ops_final_state_args_info; + extern std::unordered_map> core_ops_final_state_args_type_info; + extern std::unordered_map> core_ops_final_state_returns_info; + +""" + return core_ops_declaration_str + + +def GenerateCoreOpInfoDefinition(): + + CORE_OPS_INFO_TEMPLATE = """ +std::unordered_map> core_ops_final_state_args_info = {{ + {} +}}; +std::unordered_map> core_ops_final_state_args_type_info = {{ + {} +}}; +std::unordered_map> core_ops_final_state_returns_info = {{ + {} +}}; + +""" + op_args_info_list = [] + for op_name, arg_list in core_ops_args_info.items(): + arg_str = ",".join(["\"" + v + "\"" for v in arg_list]) + op_args_info = f"{{ \"{op_name}\", {{ {arg_str} }} }}," + op_args_info_list.append(op_args_info) + + op_types_info_list = [] + for op_name, type_list in core_ops_args_type_info.items(): + type_str = ",".join(["\"" + v + "\"" for v in type_list]) + op_types_info = f"{{ \"{op_name}\", {{ {type_str} }} }}," + op_types_info_list.append(op_types_info) + + op_returns_info_list = [] + for op_name, return_list in core_ops_returns_info.items(): + return_str = ",".join(["\"" + v + "\"" for v in return_list]) + return_types_info = f"{{ \"{op_name}\", {{ {return_str} }} }}," + op_returns_info_list.append(return_types_info) + + op_args_info_str = "\n".join(op_args_info_list) + op_types_info_str = "\n".join(op_types_info_list) + op_returns_info_str = "\n".join(op_returns_info_list) + + core_ops_info_definition_str = CORE_OPS_INFO_TEMPLATE.format( + op_args_info_str, op_types_info_str, op_returns_info_str) + + return core_ops_info_definition_str + + def GenerateNodeCCFile(filepath, node_definition_str): file_contents = """ #include "glog/logging.h" @@ -856,6 +950,8 @@ def GenerateForwardCCFile(filepath, forward_definition_str): #include "paddle/fluid/eager/api/utils/global_utils.h" """ + + file_contents += GenerateCoreOpInfoDefinition() file_contents += forward_definition_str with open(filepath, 'a') as f: f.write(file_contents) @@ -871,6 +967,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str): #include "paddle/fluid/framework/op_registry.h" """ + file_contents += GenerateCoreOpInfoDeclaration() file_contents += forward_function_declaration_str with open(filepath, 'a') as f: f.write(file_contents) @@ -985,6 +1082,11 @@ if __name__ == "__main__": forward_definition_str += definition_declaration_pair[0] forward_declaration_str += definition_declaration_pair[1] + # For python-level API dispatch + CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, + forward_outputs_position_map, + forward_attrs_list) + # Generate Files nodes_h_path = args.nodes_h_path nodes_cc_path = args.nodes_cc_path diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 60b615b50dae71e30918028f3a704e4880a19d37..f7945551ad9d465c2d297383739c868131b2a29c 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -104,6 +104,8 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj PyThreadState *tstate = nullptr; try {{ + VLOG(6) << "Running Eager Final State API: {}"; + // Get EagerTensors from args {} @@ -129,16 +131,87 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj """ python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( - fwd_api_name, get_eager_tensor_str, parse_attributes_str, + fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str, GetForwardFunctionName(fwd_api_name), dygraph_function_call_str) - python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}" + python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}},\n" return python_c_function_str, python_c_function_reg_str +def GenerateCoreOpsInfoMap(): + result = """ +static PyObject * eager_get_final_state_core_ops_args_info(PyObject *self) { + PyThreadState *tstate = nullptr; + try + { + return ToPyObject(core_ops_final_state_args_info); + } + catch(...) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + +static PyObject * eager_get_final_state_core_ops_args_type_info(PyObject *self) { + PyThreadState *tstate = nullptr; + try + { + return ToPyObject(core_ops_final_state_args_type_info); + } + catch(...) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + +static PyObject * eager_get_final_state_core_ops_returns_info(PyObject *self) { + PyThreadState *tstate = nullptr; + try + { + return ToPyObject(core_ops_final_state_returns_info); + } + catch(...) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + """ + + core_ops_infos_registry = """ + {\"get_final_state_core_ops_args_info\", + (PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_info, METH_NOARGS, + \"C++ interface function for eager_get_final_state_core_ops_args_info.\"}, + {\"get_final_state_core_ops_args_type_info\", + (PyCFunction)(void(*)(void))eager_get_final_state_core_ops_args_type_info, + METH_NOARGS, + \"C++ interface function for eager_get_final_state_core_ops_args_type_info.\"}, + {\"get_final_state_core_ops_returns_info\", + (PyCFunction)(void(*)(void))eager_get_final_state_core_ops_returns_info, + METH_NOARGS, \"C++ interface function for eager_get_final_state_core_ops_returns_info.\"}, +""" + + return result, core_ops_infos_registry + + def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): + core_ops_infos_definition, core_ops_infos_registry = GenerateCoreOpsInfoMap( + ) + + python_c_function_str += core_ops_infos_definition + python_c_function_reg_str += core_ops_infos_registry + python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}" + PYTHON_C_WRAPPER_TEMPLATE = """ #pragma once @@ -215,12 +288,12 @@ if __name__ == "__main__": python_c_function_reg_list.append(python_c_function_reg_str) print("Generated Python-C Function: ", python_c_function_str) - python_c_function_reg_list.append("{nullptr,nullptr,0,nullptr}") python_c_functions_str = "\n".join(python_c_function_list) python_c_functions_reg_str = ",\n".join(python_c_function_reg_list) python_c_str = GeneratePythonCWrappers(python_c_functions_str, python_c_functions_reg_str) + print("Generated Python-C Codes: ", python_c_str) output_path = args.output_path diff --git a/paddle/fluid/framework/custom_kernel_test.cc b/paddle/fluid/framework/custom_kernel_test.cc index 29072551c80768d2313ba4399952c21644e079e6..63dd583504d601813d8e546d0c3e8611012ae2af 100644 --- a/paddle/fluid/framework/custom_kernel_test.cc +++ b/paddle/fluid/framework/custom_kernel_test.cc @@ -22,9 +22,12 @@ limitations under the License. */ #include #include #include "paddle/extension.h" +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_kernel_info_helper.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/pten/api/lib/utils/allocator.h" -#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_factory.h" @@ -183,14 +186,14 @@ TEST(CustomKernel, custom_kernel_dot) { paddle::platform::CPUPlace()); auto dense_x = std::make_shared( alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8, - paddle::framework::make_ddim({2, 3}), + pten::framework::make_ddim({2, 3}), pten::DataLayout::NCHW)); auto* dense_x_data = dense_x->mutable_data(paddle::platform::CPUPlace()); auto dense_y = std::make_shared( alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8, - paddle::framework::make_ddim({2, 3}), + pten::framework::make_ddim({2, 3}), pten::DataLayout::NCHW)); auto* dense_y_data = dense_y->mutable_data(paddle::platform::CPUPlace()); @@ -231,8 +234,7 @@ TEST(CustomKernel, custom_kernel_dot) { pten::DataType fake_attr_dtype = pten::DataType::UINT32; paddle::framework::LoDTensor tmp_tensor; tmp_tensor.mutable_data({1}, pten::TransToPtenPlace(backend)); - pten::Scalar fake_attr_scalar = - paddle::experimental::MakePtenScalar(tmp_tensor); + pten::Scalar fake_attr_scalar{tmp_tensor}; pten::ScalarArray fake_attr_scalar_array; std::vector fake_attr_int64_vec; std::vector fake_attr_int_vec; diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 652286ab2666e6253173f6b7d5c3751a22ee788c..efe2423b0e8a5abdcccf2db823023e1037f6ec8b 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -41,6 +41,10 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { return ctx_.HasOutput(name); } + bool HasAttr(const std::string& name) const override { + return ctx_.HasAttr(name); + } + paddle::any Attr(const std::string& name) const override { auto& attr = ctx_.Attrs().GetAttr(name); return GetAttrValue(attr); @@ -278,21 +282,47 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, pten::InferMetaContext infer_meta_context(ctx->IsRuntime()); auto& input_names = std::get<0>(signature.args); + auto& attr_names = std::get<1>(signature.args); auto& output_names = std::get<2>(signature.args); - // TODO(chenweihang): support attrs in next pr - // auto& attr_names = std::get<1>(signature.args); - // TODO(chenweihang): support multiple inputs and outputs + // TODO(chenweihang): support multiple inputs and outputs later pten::InferMetaContext infer_mete_context; for (auto& in_name : input_names) { - infer_meta_context.EmplaceBackInput(std::make_shared( - ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime())); + if (ctx->HasInput(in_name)) { + infer_meta_context.EmplaceBackInput(std::make_shared( + ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime())); + } else { + infer_meta_context.EmplaceBackInput({nullptr}); + } } + + auto attr_reader = ctx->Attrs(); + for (auto& attr_name : attr_names) { + if (ctx->HasAttr(attr_name)) { + auto& attr = attr_reader.GetAttr(attr_name); + if (std::type_index(attr.type()) == std::type_index(typeid(bool))) { + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(float))) { + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); + } else { + // do nothing, skip useless attrs now + // TODO(chenweihang): support other attr type later and throw error + // if attr is cannot parsed + } + } else { + // do nothing + } + } + for (auto& out_name : output_names) { - infer_meta_context.EmplaceBackOutput(std::make_shared( - ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime())); + if (ctx->HasOutput(out_name)) { + infer_meta_context.EmplaceBackOutput(std::make_shared( + ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime())); + } else { + infer_meta_context.EmplaceBackOutput({nullptr}); + } } - // TODO(chenweihang): support attrs later return infer_meta_context; } diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 518d866afccde35e283bdbd2e45fed3e1b480aec..8a5ec83b8b364addc08cbf0047f91012f63fda54 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -475,12 +475,11 @@ void InterpreterCore::ExecuteInstructionList( if (UNLIKELY(exception_holder_.IsCaught())) { VLOG(1) << "Exception caught " << exception_holder_.Type(); - // NOTE(xiongkun) Why we reset ? - // The caught exception may be EOFExcetion, under this situation, we need - // make async_work_queue_ available, so we need reset. - async_work_queue_->Cancel(); - async_work_queue_.reset(new interpreter::AsyncWorkQueue( - kHostNumThreads, &main_thread_blocker_)); + // Graceful exit when the executor encountered a fatal error. + // EOF is not a fatal error. + if (exception_holder_.Type() != "EOF") { + async_work_queue_->Cancel(); + } PADDLE_ENFORCE_EQ( main_thread_blocker_.Clear(), 0, platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index c72cbda008f3baf24c712ca3b35e68fb25f0ea06..67d60975c95d9a39676b9bb7cd7ee482c9fab3cd 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -74,6 +74,10 @@ bool InterpretercoreInferShapeContext::HasOutput( return out[0] != nullptr; } +bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const { + return op_.HasAttr(name); +} + bool InterpretercoreInferShapeContext::HasInputs( const std::string& name) const { const auto& ins = ctx_.inputs; diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index b61b8af1e4a1b38f3db686e3b438aaf7745ed3c0..e00b1daf28a9ff469bfd0b81ca620161844f94b4 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -54,6 +54,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext { bool HasOutput(const std::string& name) const override; + bool HasAttr(const std::string& name) const override; + bool HasInputs(const std::string& name) const override; bool HasOutputs(const std::string& name) const override; diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 7bceeb05bac599ef7150435eff5ed67b4076e846..942beb6e9a885283293b1823cc4e1b89bf1905d0 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -35,6 +35,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasOutput(const std::string &name) const override; + bool HasAttr(const std::string &name) const override; + bool HasInputs(const std::string &name) const override; bool HasOutputs(const std::string &name) const override; @@ -855,6 +857,10 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const { return block_.HasVarRecursive(output_names[0]); } +bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const { + return op_.HasAttr(name); +} + bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const { if (op_.Inputs().find(name) == op_.Inputs().end()) { return false; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9cfdd7ee85bfc2e4647f47482bdc27cc734caf2f..dd76b0e7d7d492c7cff3a7fb16ec6cd5251b16ce 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -664,6 +664,10 @@ class RuntimeInferShapeContext : public InferShapeContext { return out[0] != nullptr; } + bool HasAttr(const std::string& name) const override { + return op_.HasAttr(name); + } + bool HasInputs(const std::string& name) const override { const auto& ins = ctx_.inputs; auto it = ins.find(name); @@ -2099,6 +2103,10 @@ void OperatorWithKernel::BuildPtenKernelContext( std::type_index(typeid(std::vector))) { pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray( BOOST_GET_CONST(std::vector, attr_iter->second)))); + } else if (std::type_index(attr_iter->second.type()) == + std::type_index(typeid(int32_t))) { + pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray( + &BOOST_GET_CONST(int32_t, attr_iter->second), 1))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to ScalarArray when " diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 2294d67fbf2f32f59b89b1164823c07c7e08bd39..db529bd17f4ab8c40b555c0be83fa6cd3daa1f14 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -455,6 +455,10 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { return ctx_.HasOutput(name); } + bool HasAttr(const std::string& name) const override { + return ctx_.HasAttr(name); + } + paddle::any Attr(const std::string& name) const override { auto& attr = ctx_.GetAttr(name); return GetAttrValue(attr); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 791600b39c3d94a911f6eae9113dc703392a7e55..09568168d8526a8352700e7edb2b2bae181eba20 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -61,6 +61,7 @@ class InferShapeContext { virtual ~InferShapeContext() = default; virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; + virtual bool HasAttr(const std::string &name) const = 0; virtual std::vector GetInputsVarType( const std::string &name) const = 0; diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 7033b9c11712dcefd49d42894fec6283eb064c9f..554657c71387b1713ec1d70526e723d2bf7a3cac 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -78,6 +78,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext { return out[0] != nullptr; } + bool HasAttr(const std::string& name) const override { + return attrs_->count(name) > 0 || default_attrs_->count(name) > 0; + } + bool HasInputs(const std::string& name) const override { auto it = var_map_in_->find(name); if (it == var_map_in_->end() || it->second.empty()) { diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 568a094a8f9250c9eecc5eabce173128b1f938f5..ab13e19171d1b0099520bbd445742e2c2c6b7ee5 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -346,6 +346,14 @@ void BuildDygraphPtenKernelContext( std::type_index(typeid(std::vector))) { kernel_ctx->EmplaceBackAttr(std::move( pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int64_t))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::ScalarArray(&BOOST_GET_CONST(int64_t, attr), 1))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int32_t))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::ScalarArray(&BOOST_GET_CONST(int32_t, attr), 1))); } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index b4ff3cff38217a57c0b1091c3e003043ca4c9673..fa52aa6d0af61578e18d51e8b95c13b5d383c858 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -217,7 +217,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) { } // namespace imperative } // namespace paddle -USE_OP(split); +USE_OP_ITSELF(split); USE_OP(relu); #ifdef PADDLE_WITH_MKLDNN USE_OP_DEVICE_KERNEL(relu, MKLDNN); diff --git a/paddle/fluid/operators/collective/c_broadcast_op_mlu.cc b/paddle/fluid/operators/collective/c_broadcast_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4061254edb1b712555d57668263bfc78086c3e9 --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op_mlu.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_broadcast_op.h" + +#if defined(PADDLE_WITH_CNCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/mlu/cncl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CBroadcastOPMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_CNCL) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int numel = x->numel(); + cnclDataType_t dtype = platform::ToCNCLDataType(x->type()); + + int rid = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + auto comm = platform::CNCLCommContext::Instance().Get(rid, place); + + mluStream stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + int root = ctx.Attr("root"); + if (root == comm->rank()) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnclBcast(reinterpret_cast(const_cast(x->data())), + numel, dtype, root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " + << x->numel(); + + if (out != x) { + framework::TensorCopy( + *static_cast(x), place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + PADDLE_ENFORCE_MLU_SUCCESS(cnclBcast(out->mutable_data(place), numel, + dtype, root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved " + << framework::product(out->dims()); + } + + out->Resize(x->dims()); + out->set_lod(x->lod()); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with MLU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(c_broadcast, ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel); diff --git a/paddle/fluid/operators/gather_nd_op_xpu.cc b/paddle/fluid/operators/gather_nd_op_xpu.cc index c7e4169865fa6158934ca9a93d99b488ff9c0286..a86731bba8ab8560592867e04a9ff6ac469f2f94 100644 --- a/paddle/fluid/operators/gather_nd_op_xpu.cc +++ b/paddle/fluid/operators/gather_nd_op_xpu.cc @@ -47,8 +47,12 @@ class GatherNdXPUKernel : public framework::OpKernel { auto x_shape = paddle::framework::vectorize(x->dims()); auto index_shape = paddle::framework::vectorize(index->dims()); + if (index_shape.size() == 1) { + index_shape.insert(index_shape.begin(), 1); + } xpu::VectorParam x_vec = {x_shape.data(), static_cast(x_shape.size()), nullptr}; + auto &dev_ctx = ctx.template device_context(); int ret = XPU_SUCCESS; diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 5add86f5b3c74ac82ea4f4fbb0c8c1c9cb0d00f6..f7a5e2a8af409500ea7d51dc83ef32c5aad9142a 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -16,6 +16,10 @@ #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/infermeta/backward.h" + namespace paddle { namespace operators { @@ -343,25 +347,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2"); - OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2"); - OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "matmul_v2"); - auto x_dims = context->GetInputDim("X"); - auto y_dims = context->GetInputDim("Y"); - - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - - if (context->HasOutput(x_grad_name)) { - context->SetOutputDim(x_grad_name, x_dims); - } - if (context->HasOutput(y_grad_name)) { - context->SetOutputDim(y_grad_name, y_dims); - } - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( @@ -539,9 +524,12 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, ops::MatMulV2GradOpMaker, ops::MatMulV2GradOpMaker); +DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor, + PT_INFER_META(pten::MatmulGradInferMeta)); REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, ops::MatMulV2OpDoubleGradMaker, - ops::MatMulV2OpDoubleGradMaker); + ops::MatMulV2OpDoubleGradMaker, + MatMulV2GradInferShapeFunctor); REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad, ops::MatMulV2OpTripleGradMaker, diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 5bd699e08abbcad5524a500c28ec7d7768dc18f0..79636aced0333284d00f0bbcd96616c01ffd88c9 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -172,11 +172,3 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker, ops::SplitGradMaker); -namespace plat = paddle::platform; -REGISTER_OP_CPU_KERNEL( - split, ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel); diff --git a/paddle/fluid/operators/split_op.cu.cc b/paddle/fluid/operators/split_op.cu.cc deleted file mode 100644 index a8a1383614bddb24b285734edb6f74e2789fdfeb..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/split_op.cu.cc +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/split_op.h" -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - split, ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel); diff --git a/paddle/fluid/operators/split_op.h b/paddle/fluid/operators/split_op.h index 96ac2c7a1bd086c2ca937d26160a6ac9316c92cc..0538fad08278e774889cc55cae9f5b72da0d27e3 100644 --- a/paddle/fluid/operators/split_op.h +++ b/paddle/fluid/operators/split_op.h @@ -19,10 +19,8 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" - +#include "paddle/pten/kernels/split_kernel.h" namespace paddle { namespace operators { static inline std::vector UpdateOutsDims( @@ -108,56 +106,6 @@ static inline std::vector UpdateOutsDims( } return outs_dims; } -template -class SplitOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto outs = ctx.MultiOutput("Out"); - int num = ctx.Attr("num"); - std::vector sections = ctx.Attr>("sections"); - int axis = ctx.Attr("axis"); - - auto in_dims = in->dims(); - auto outs_number = outs.size(); - - bool need_resize_outs_dims = false; - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - axis = GetDataFromTensor(axis_tensor)[0]; - need_resize_outs_dims = true; - } - auto sections_tensor_list = - ctx.MultiInput("SectionsTensorList"); - if (sections_tensor_list.size() > 0) { - sections = GetDataFromTensorList(sections_tensor_list); - need_resize_outs_dims = true; - } - - if (need_resize_outs_dims) { - std::vector outs_dims = - UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number); - for (size_t j = 0; j < outs.size(); ++j) { - outs[j]->Resize(outs_dims[j]); - } - } - - std::vector shape_refer; - for (size_t j = 0; j < outs.size(); ++j) { - outs[j]->mutable_data(ctx.GetPlace()); - shape_refer.emplace_back(outs[j]); - } - - auto& dev_ctx = ctx.template device_context(); - // Sometimes direct copies will be faster, this maybe need deeply analysis. - if (axis == 0 && outs.size() < 10) { - StridedMemcpyWithAxis0(dev_ctx, *in, shape_refer, &outs); - } else { - math::SplitFunctor functor; - functor(dev_ctx, *in, shape_refer, axis, &outs); - } - } -}; template class SplitGradMaker : public framework::SingleGradOpMaker { diff --git a/paddle/fluid/platform/profiler/CMakeLists.txt b/paddle/fluid/platform/profiler/CMakeLists.txt index 72bf5134cc18d0a8841875e0fbd04b8c02bf8280..626847f04653cae1acec7dc06d594700aa5d1d70 100644 --- a/paddle/fluid/platform/profiler/CMakeLists.txt +++ b/paddle/fluid/platform/profiler/CMakeLists.txt @@ -1,5 +1,6 @@ cc_library(host_tracer SRCS host_tracer.cc DEPS enforce) -cc_library(new_profiler SRCS profiler.cc DEPS host_tracer) +cc_library(cuda_tracer SRCS cuda_tracer.cc cupti_data_process.cc DEPS workqueue_utils enforce glog) +cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer) cc_library(event_node SRCS event_node.cc DEPS enforce) cc_library(chrometracinglogger SRCS chrometracing_logger.cc DEPS event_node) cc_test(test_event_node SRCS test_event_node.cc DEPS event_node chrometracinglogger) diff --git a/paddle/fluid/platform/profiler/cuda_tracer.cc b/paddle/fluid/platform/profiler/cuda_tracer.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d3e354dc271a0241a9b63005aa29970d1548109 --- /dev/null +++ b/paddle/fluid/platform/profiler/cuda_tracer.cc @@ -0,0 +1,191 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/profiler/cuda_tracer.h" +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" +#include "paddle/fluid/platform/os_info.h" +#include "paddle/fluid/platform/profiler/cupti_data_process.h" + +#define CUPTI_CALL(call) \ + do { \ + CUptiResult _status = call; \ + if (_status != CUPTI_SUCCESS) { \ + const char* errstr; \ + dynload::cuptiGetResultString(_status, &errstr); \ + LOG(ERROR) << "Function " << #call << " failed with error " << errstr; \ + exit(-1); \ + } \ + } while (0) + +namespace paddle { +namespace platform { + +namespace details { +std::unordered_map CreateThreadIdMapping() { + std::unordered_map mapping; + std::unordered_map ids = GetAllThreadIds(); + for (const auto& id : ids) { + mapping[id.second.cupti_tid] = id.second.sys_tid; + } + return mapping; +} +} // namespace details + +CudaTracer::CudaTracer() {} + +void CudaTracer::PrepareTracing() { + PADDLE_ENFORCE_EQ( + state_ == TracerState::UNINITED || state_ == TracerState::STOPED, true, + platform::errors::PreconditionNotMet("Tracer must be UNINITED")); + EnableCuptiActivity(); + state_ = TracerState::READY; +} + +void CudaTracer::StartTracing() { + PADDLE_ENFORCE_EQ( + state_ == TracerState::READY, true, + platform::errors::PreconditionNotMet("Tracer must be READY or STOPPED")); + ConsumeBuffers(); + tracing_start_ns_ = PosixInNsec(); + state_ = TracerState::STARTED; +} + +void CudaTracer::StopTracing() { + PADDLE_ENFORCE_EQ( + state_, TracerState::STARTED, + platform::errors::PreconditionNotMet("Tracer must be STARTED")); + DisableCuptiActivity(); + state_ = TracerState::STOPED; +} + +void CudaTracer::CollectTraceData(TraceEventCollector* collector) { + PADDLE_ENFORCE_EQ( + state_, TracerState::STOPED, + platform::errors::PreconditionNotMet("Tracer must be STOPED")); + ProcessCuptiActivity(collector); +} + +int CudaTracer::ProcessCuptiActivity(TraceEventCollector* collector) { + int record_cnt = 0; +#ifdef PADDLE_WITH_CUPTI + CUPTI_CALL(dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED)); + auto mapping = details::CreateThreadIdMapping(); + std::vector buffers = ConsumeBuffers(); + for (auto& buffer : buffers) { + if (buffer.addr == nullptr || buffer.valid_size == 0) { + continue; + } + + CUpti_Activity* record = nullptr; + while (true) { + CUptiResult status = dynload::cuptiActivityGetNextRecord( + buffer.addr, buffer.valid_size, &record); + if (status == CUPTI_SUCCESS) { + details::ProcessCuptiActivityRecord(record, tracing_start_ns_, mapping, + collector); + ++record_cnt; + } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + CUPTI_CALL(status); + } + } + + ReleaseBuffer(buffer.addr); + } +#endif + return record_cnt; +} + +void CudaTracer::EnableCuptiActivity() { +#ifdef PADDLE_WITH_CUPTI + CUPTI_CALL(dynload::cuptiActivityRegisterCallbacks(BufferRequestedCallback, + BufferCompletedCallback)); + + CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY)); + CUPTI_CALL( + dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); + CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER)); + CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME)); + CUPTI_CALL(dynload::cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET)); + VLOG(3) << "enable cupti activity"; +#endif +} + +void CudaTracer::DisableCuptiActivity() { +#ifdef PADDLE_WITH_CUPTI + CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY)); + CUPTI_CALL( + dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); + CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER)); + CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME)); + CUPTI_CALL(dynload::cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET)); + VLOG(3) << "disable cupti activity"; +#endif +} + +#ifdef PADDLE_WITH_CUPTI +void CUPTIAPI CudaTracer::BufferRequestedCallback(uint8_t** buffer, + size_t* size, + size_t* max_num_records) { + GetInstance().AllocateBuffer(buffer, size); + *max_num_records = 0; +} + +void CUPTIAPI CudaTracer::BufferCompletedCallback(CUcontext ctx, + uint32_t stream_id, + uint8_t* buffer, size_t size, + size_t valid_size) { + GetInstance().ProduceBuffer(buffer, valid_size); + size_t dropped = 0; + CUPTI_CALL( + dynload::cuptiActivityGetNumDroppedRecords(ctx, stream_id, &dropped)); + if (dropped != 0) { + LOG(WARNING) << "Stream " << stream_id << " Dropped " << dropped + << " activity records"; + } +} +#endif + +void CudaTracer::AllocateBuffer(uint8_t** buffer, size_t* size) { + constexpr size_t kBufSize = 1 << 23; // 8 MB + constexpr size_t kBufAlign = 8; // 8 B + *buffer = reinterpret_cast( + paddle::framework::AlignedMalloc(kBufSize, kBufAlign)); + *size = kBufSize; +} + +void CudaTracer::ProduceBuffer(uint8_t* buffer, size_t valid_size) { + std::lock_guard guard(activity_buffer_lock_); + activity_buffers_.emplace_back(buffer, valid_size); +} + +std::vector CudaTracer::ConsumeBuffers() { + std::vector buffers; + { + std::lock_guard guard(activity_buffer_lock_); + buffers.swap(activity_buffers_); + } + return buffers; +} + +void CudaTracer::ReleaseBuffer(uint8_t* buffer) { + paddle::framework::AlignedFree(buffer); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/cuda_tracer.h b/paddle/fluid/platform/profiler/cuda_tracer.h new file mode 100644 index 0000000000000000000000000000000000000000..20a60521266a2b32e01508a59981956870ee09dc --- /dev/null +++ b/paddle/fluid/platform/profiler/cuda_tracer.h @@ -0,0 +1,87 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/platform/dynload/cupti.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/profiler/tracer_base.h" + +namespace paddle { +namespace platform { + +// Based on CUDA CUPTI +class CudaTracer : public TracerBase { + public: + // Singleton. CUPTI imposes this restriction. + static CudaTracer& GetInstance() { + static CudaTracer instance; + return instance; + } + + void PrepareTracing() override; + + void StartTracing() override; + + void StopTracing() override; + + void CollectTraceData(TraceEventCollector* collector) override; + + private: + struct ActivityBuffer { + ActivityBuffer(uint8_t* addr, size_t size) : addr(addr), valid_size(size) {} + uint8_t* addr; + size_t valid_size; + }; + + CudaTracer(); + + DISABLE_COPY_AND_ASSIGN(CudaTracer); + + void EnableCuptiActivity(); + + void DisableCuptiActivity(); + + int ProcessCuptiActivity(TraceEventCollector* collector); + +#ifdef PADDLE_WITH_CUPTI + // Used by CUPTI Activity API to request buffer + static void CUPTIAPI BufferRequestedCallback(uint8_t** buffer, size_t* size, + size_t* max_num_records); + + // Used by CUPTI Activity API to commit a completed buffer + static void CUPTIAPI BufferCompletedCallback(CUcontext ctx, + uint32_t stream_id, + uint8_t* buffer, size_t size, + size_t valid_size); +#endif + + void AllocateBuffer(uint8_t** buffer, size_t* size); + + void ProduceBuffer(uint8_t* buffer, size_t valid_size); + + std::vector ConsumeBuffers(); + + void ReleaseBuffer(uint8_t* buffer); + + uint64_t tracing_start_ns_ = UINT64_MAX; + std::mutex activity_buffer_lock_; + std::vector activity_buffers_; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/cupti_data_process.cc b/paddle/fluid/platform/profiler/cupti_data_process.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d3b807aba82ea91770dddfcf655ec2431cdb197 --- /dev/null +++ b/paddle/fluid/platform/profiler/cupti_data_process.cc @@ -0,0 +1,304 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/profiler/cupti_data_process.h" +#include +#include "paddle/fluid/platform/os_info.h" + +namespace paddle { +namespace platform { +namespace details { +#ifdef PADDLE_WITH_CUPTI +void AddKernelRecord(const CUpti_ActivityKernel4* kernel, uint64_t start_ns, + TraceEventCollector* collector) { + if (kernel->start < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = kernel->name; + event.type = TracerEventType::Kernel; + event.start_ns = kernel->start; + event.end_ns = kernel->end; + event.device_id = kernel->deviceId; + event.context_id = kernel->contextId; + event.stream_id = kernel->streamId; + event.correlation_id = kernel->correlationId; + event.kernel_info.block_x = kernel->blockX; + event.kernel_info.block_y = kernel->blockY; + event.kernel_info.block_z = kernel->blockZ; + event.kernel_info.grid_x = kernel->gridX; + event.kernel_info.grid_y = kernel->gridY; + event.kernel_info.grid_z = kernel->gridZ; + event.kernel_info.dynamic_shared_memory = kernel->dynamicSharedMemory; + event.kernel_info.static_shared_memory = kernel->staticSharedMemory; + event.kernel_info.registers_per_thread = kernel->registersPerThread; + event.kernel_info.local_memory_per_thread = kernel->localMemoryPerThread; + event.kernel_info.local_memory_total = kernel->localMemoryTotal; + event.kernel_info.queued = kernel->queued; + event.kernel_info.submitted = kernel->submitted; + event.kernel_info.completed = kernel->completed; + collector->AddDeviceEvent(std::move(event)); +} + +const char* MemcpyKind(uint8_t kind) { + switch (kind) { + case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD: + return "MEMCPY_HtoD"; + case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH: + return "MEMCPY_DtoH"; + case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA: + return "MEMCPY_HtoA"; + case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH: + return "MEMCPY_AtoH"; + case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA: + return "MEMCPY_AtoA"; + case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD: + return "MEMCPY_AtoD"; + case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA: + return "MEMCPY_DtoA"; + case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD: + return "MEMCPY_DtoD"; + case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH: + return "MEMCPY_HtoH"; + case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP: + return "MEMCPY_PtoP"; + default: + return "MEMCPY"; + } +} + +const char* MemoryKind(uint16_t kind) { + switch (kind) { + case CUPTI_ACTIVITY_MEMORY_KIND_UNKNOWN: + return "Unknown"; + case CUPTI_ACTIVITY_MEMORY_KIND_PAGEABLE: + return "Pageable"; + case CUPTI_ACTIVITY_MEMORY_KIND_PINNED: + return "Pinned"; + case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE: + return "Device"; + case CUPTI_ACTIVITY_MEMORY_KIND_ARRAY: + return "Array"; + case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED: + return "Managed"; + case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE_STATIC: + return "Device Static"; + case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED_STATIC: + return "Managed Static"; + default: + return "Unknown"; + } +} + +void AddMemcpyRecord(const CUpti_ActivityMemcpy* memcpy, uint64_t start_ns, + TraceEventCollector* collector) { + if (memcpy->start < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = MemcpyKind(memcpy->copyKind); + event.type = TracerEventType::Memcpy; + event.start_ns = memcpy->start; + event.end_ns = memcpy->end; + event.device_id = memcpy->deviceId; + event.context_id = memcpy->contextId; + event.stream_id = memcpy->streamId; + event.correlation_id = memcpy->correlationId; + event.memcpy_info.num_bytes = memcpy->bytes; + // snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s", + // MemcpyKind(memcpy->copyKind)); + snprintf(event.memcpy_info.src_kind, kMemKindMaxLen, "%s", + MemcpyKind(memcpy->srcKind)); + snprintf(event.memcpy_info.dst_kind, kMemKindMaxLen, "%s", + MemcpyKind(memcpy->dstKind)); + collector->AddDeviceEvent(std::move(event)); +} + +void AddMemcpy2Record(const CUpti_ActivityMemcpy2* memcpy2, uint64_t start_ns, + TraceEventCollector* collector) { + if (memcpy2->start < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = MemcpyKind(memcpy2->copyKind); + event.type = TracerEventType::Memcpy; + event.start_ns = memcpy2->start; + event.end_ns = memcpy2->end; + event.device_id = memcpy2->deviceId; + event.context_id = memcpy2->contextId; + event.stream_id = memcpy2->streamId; + event.correlation_id = memcpy2->correlationId; + event.memcpy_info.num_bytes = memcpy2->bytes; + // snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s", + // MemcpyKind(memcpy2->copyKind)); + snprintf(event.memcpy_info.src_kind, kMemKindMaxLen, "%s", + MemcpyKind(memcpy2->srcKind)); + snprintf(event.memcpy_info.dst_kind, kMemKindMaxLen, "%s", + MemcpyKind(memcpy2->dstKind)); + collector->AddDeviceEvent(std::move(event)); +} + +void AddMemsetRecord(const CUpti_ActivityMemset* memset, uint64_t start_ns, + TraceEventCollector* collector) { + if (memset->start < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = "MEMSET"; + event.type = TracerEventType::Memset; + event.start_ns = memset->start; + event.end_ns = memset->end; + event.device_id = memset->deviceId; + event.context_id = memset->contextId; + event.stream_id = memset->streamId; + event.correlation_id = memset->correlationId; + event.memset_info.num_bytes = memset->bytes; + snprintf(event.memset_info.memory_kind, kMemKindMaxLen, "%s", + MemoryKind(memset->memoryKind)); + event.memset_info.value = memset->value; + collector->AddDeviceEvent(std::move(event)); +} + +class CuptiRuntimeCbidStr { + public: + static const CuptiRuntimeCbidStr& GetInstance() { + static CuptiRuntimeCbidStr inst; + return inst; + } + + std::string RuntimeKind(CUpti_CallbackId cbid) const { + auto iter = cbid_str_.find(cbid); + if (iter == cbid_str_.end()) { + return "Runtime API " + std::to_string(cbid); + } + return iter->second; + } + + private: + CuptiRuntimeCbidStr(); + + std::unordered_map cbid_str_; +}; + +CuptiRuntimeCbidStr::CuptiRuntimeCbidStr() { +#define REGISTER_RUNTIME_CBID_STR(cbid) \ + cbid_str_[CUPTI_RUNTIME_TRACE_CBID_##cbid] = #cbid + REGISTER_RUNTIME_CBID_STR(cudaBindTexture_v3020); + REGISTER_RUNTIME_CBID_STR(cudaConfigureCall_v3020); + REGISTER_RUNTIME_CBID_STR(cudaDeviceGetAttribute_v5000); + REGISTER_RUNTIME_CBID_STR(cudaDeviceGetStreamPriorityRange_v5050); + REGISTER_RUNTIME_CBID_STR(cudaDeviceSynchronize_v3020); + REGISTER_RUNTIME_CBID_STR(cudaDriverGetVersion_v3020); + REGISTER_RUNTIME_CBID_STR(cudaEventCreateWithFlags_v3020); + REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020); + REGISTER_RUNTIME_CBID_STR(cudaEventDestroy_v3020); + REGISTER_RUNTIME_CBID_STR(cudaEventQuery_v3020); + REGISTER_RUNTIME_CBID_STR(cudaEventRecord_v3020); + REGISTER_RUNTIME_CBID_STR(cudaFreeHost_v3020); + REGISTER_RUNTIME_CBID_STR(cudaFree_v3020); + REGISTER_RUNTIME_CBID_STR(cudaFuncGetAttributes_v3020); + REGISTER_RUNTIME_CBID_STR(cudaGetDeviceCount_v3020); + REGISTER_RUNTIME_CBID_STR(cudaGetDeviceProperties_v3020); + REGISTER_RUNTIME_CBID_STR(cudaGetDevice_v3020); + REGISTER_RUNTIME_CBID_STR(cudaGetErrorString_v3020); + REGISTER_RUNTIME_CBID_STR(cudaGetLastError_v3020); + REGISTER_RUNTIME_CBID_STR(cudaHostAlloc_v3020); + REGISTER_RUNTIME_CBID_STR(cudaHostGetDevicePointer_v3020); + REGISTER_RUNTIME_CBID_STR(cudaLaunchKernel_v7000); + REGISTER_RUNTIME_CBID_STR(cudaMallocHost_v3020); + REGISTER_RUNTIME_CBID_STR(cudaMalloc_v3020); + REGISTER_RUNTIME_CBID_STR(cudaMemcpyAsync_v3020); + REGISTER_RUNTIME_CBID_STR(cudaMemcpy_v3020); + REGISTER_RUNTIME_CBID_STR(cudaMemsetAsync_v3020); + REGISTER_RUNTIME_CBID_STR(cudaMemset_v3020); + REGISTER_RUNTIME_CBID_STR( + cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_v7000); + REGISTER_RUNTIME_CBID_STR(cudaPeekAtLastError_v3020); + REGISTER_RUNTIME_CBID_STR(cudaRuntimeGetVersion_v3020); + REGISTER_RUNTIME_CBID_STR(cudaSetDevice_v3020); + REGISTER_RUNTIME_CBID_STR(cudaStreamCreate_v3020); + REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithFlags_v5000); + REGISTER_RUNTIME_CBID_STR(cudaStreamCreateWithPriority_v5050); + REGISTER_RUNTIME_CBID_STR(cudaStreamDestroy_v5050); + REGISTER_RUNTIME_CBID_STR(cudaStreamSynchronize_v3020); + REGISTER_RUNTIME_CBID_STR(cudaStreamWaitEvent_v3020); + REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020); + REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020); + REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020); + REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010); +#if CUDA_VERSION >= 9000 + REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000); + REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000); +#endif +#undef REGISTER_RUNTIME_CBID_STR +} + +void AddApiRecord(const CUpti_ActivityAPI* api, uint64_t start_ns, + const std::unordered_map tid_mapping, + TraceEventCollector* collector) { + if (api->start < start_ns) { + return; + } + RuntimeTraceEvent event; + event.name = CuptiRuntimeCbidStr::GetInstance().RuntimeKind(api->cbid); + event.start_ns = api->start; + event.end_ns = api->end; + event.process_id = GetProcessId(); + uint64_t tid = 0; + auto iter = tid_mapping.find(api->threadId); + if (iter == tid_mapping.end()) { + } else { + tid = iter->second; + } + event.thread_id = tid; + event.correlation_id = api->correlationId; + event.callback_id = api->cbid; + collector->AddRuntimeEvent(std::move(event)); +} + +void ProcessCuptiActivityRecord( + const CUpti_Activity* record, uint64_t start_ns, + const std::unordered_map tid_mapping, + TraceEventCollector* collector) { + switch (record->kind) { + case CUPTI_ACTIVITY_KIND_KERNEL: + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: + AddKernelRecord(reinterpret_cast(record), + start_ns, collector); + break; + case CUPTI_ACTIVITY_KIND_MEMCPY: + AddMemcpyRecord(reinterpret_cast(record), + start_ns, collector); + break; + case CUPTI_ACTIVITY_KIND_MEMCPY2: + AddMemcpy2Record(reinterpret_cast(record), + start_ns, collector); + break; + case CUPTI_ACTIVITY_KIND_MEMSET: + AddMemsetRecord(reinterpret_cast(record), + start_ns, collector); + break; + case CUPTI_ACTIVITY_KIND_DRIVER: + case CUPTI_ACTIVITY_KIND_RUNTIME: + AddApiRecord(reinterpret_cast(record), start_ns, + tid_mapping, collector); + break; + default: + break; + } +} +#endif +} // namespace details +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/cupti_data_process.h b/paddle/fluid/platform/profiler/cupti_data_process.h new file mode 100644 index 0000000000000000000000000000000000000000..01b2e72ade4e2e0d8061bad6cbcfa539a7dd8275 --- /dev/null +++ b/paddle/fluid/platform/profiler/cupti_data_process.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/platform/dynload/cupti.h" +#include "paddle/fluid/platform/profiler/trace_event_collector.h" + +namespace paddle { +namespace platform { +namespace details { +#ifdef PADDLE_WITH_CUPTI +void ProcessCuptiActivityRecord( + const CUpti_Activity* record, uint64_t start_ns, + const std::unordered_map tid_mapping, + TraceEventCollector* collector); +#endif +} // namespace details +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/host_event_recorder.h b/paddle/fluid/platform/profiler/host_event_recorder.h index 9c810dc184c00381a23f1f08da26acab8bbe0b3c..3bcd68c55963082bfc0ce12bbcdc0b07a05bbe97 100644 --- a/paddle/fluid/platform/profiler/host_event_recorder.h +++ b/paddle/fluid/platform/profiler/host_event_recorder.h @@ -1,16 +1,16 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/paddle/fluid/platform/profiler/host_tracer.cc b/paddle/fluid/platform/profiler/host_tracer.cc index 80f9a5d9af1e0a04be2074aabc44abf6af928fca..2172fe4d1e3d5786492ea8741b5e50146648e59d 100644 --- a/paddle/fluid/platform/profiler/host_tracer.cc +++ b/paddle/fluid/platform/profiler/host_tracer.cc @@ -1,16 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/platform/profiler/host_tracer.h" #include "glog/logging.h" diff --git a/paddle/fluid/platform/profiler/host_tracer.h b/paddle/fluid/platform/profiler/host_tracer.h index c73b5eca15f0e000eafea02360d8d94f2152192c..b6c10e558b787cd84e760fb892bd75ebace90c3c 100644 --- a/paddle/fluid/platform/profiler/host_tracer.h +++ b/paddle/fluid/platform/profiler/host_tracer.h @@ -1,16 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index 96fa157f3995f19369460cdb3e2424bd59aefa37..5784d6e671bbbc69a7762e5a0e757310fc5e7a3b 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -1,16 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/platform/profiler/profiler.h" #include "glog/logging.h" @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #endif #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/profiler/cuda_tracer.h" #include "paddle/fluid/platform/profiler/host_tracer.h" #include "paddle/fluid/platform/profiler/trace_event_collector.h" @@ -46,6 +47,7 @@ Profiler::Profiler(const ProfilerOptions& options) { HostTracerOptions host_tracer_options; host_tracer_options.trace_level = options.trace_level; tracers_.emplace_back(new HostTracer(host_tracer_options), true); + tracers_.emplace_back(&CudaTracer::GetInstance(), false); } Profiler::~Profiler() { alive_.store(false); } diff --git a/paddle/fluid/platform/profiler/profiler.h b/paddle/fluid/platform/profiler/profiler.h index 33fc844b0f385796baac52a2ececf29bb77421bc..de5a0cc9be4ede29ac70409edaac5541c53c5c96 100644 --- a/paddle/fluid/platform/profiler/profiler.h +++ b/paddle/fluid/platform/profiler/profiler.h @@ -1,16 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/paddle/fluid/platform/profiler/profiler_test.cc b/paddle/fluid/platform/profiler/profiler_test.cc index 6bd3ed9d8099b35f901df38b5775b1637e60e485..160c801dc6e3efa0a73ad132cc5509b03f7cffa8 100644 --- a/paddle/fluid/platform/profiler/profiler_test.cc +++ b/paddle/fluid/platform/profiler/profiler_test.cc @@ -1,16 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -44,10 +44,44 @@ TEST(ProfilerTest, TestHostTracer) { } auto nodetree = profiler->Stop(); std::set host_events; - for (const auto pair : nodetree->Traverse(true)) + for (const auto pair : nodetree->Traverse(true)) { for (const auto evt : pair.second) { host_events.insert(evt->Name()); } + } EXPECT_EQ(host_events.count("TestTraceLevel_record1"), 1u); EXPECT_EQ(host_events.count("TestTraceLevel_record2"), 0u); } + +TEST(ProfilerTest, TestCudaTracer) { + using paddle::platform::ProfilerOptions; + using paddle::platform::Profiler; + ProfilerOptions options; + options.trace_level = 0; + auto profiler = Profiler::Create(options); + EXPECT_TRUE(profiler); + profiler->Prepare(); + profiler->Start(); +#ifdef PADDLE_WITH_CUDA + cudaStream_t stream; + cudaStreamCreate(&stream); + cudaStreamSynchronize(stream); +#endif +#ifdef PADDLE_WITH_HIP + hipStream_t stream; + hipStreamCreate(&stream); + hipStreamSynchronize(stream); +#endif + auto nodetree = profiler->Stop(); + std::vector runtime_events; + for (const auto pair : nodetree->Traverse(true)) { + for (const auto host_node : pair.second) { + for (auto runtime_node : host_node->GetRuntimeTraceEventNodes()) { + runtime_events.push_back(runtime_node->Name()); + } + } + } +#ifdef PADDLE_WITH_CUPTI + EXPECT_GT(runtime_events.size(), 0u); +#endif +} diff --git a/paddle/fluid/platform/profiler/tracer_base.h b/paddle/fluid/platform/profiler/tracer_base.h index 1d4e3447fe64e4b395d1e48056e59195dc7d15c5..131159baff01bbd684de225917cdff7db7f5c2d1 100644 --- a/paddle/fluid/platform/profiler/tracer_base.h +++ b/paddle/fluid/platform/profiler/tracer_base.h @@ -1,16 +1,16 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 1a6bd9f35aa53ae0414e7ed5cb1eae3c34f3a856..85a39710564bc8c1b56a76035f7b2c56628ecf95 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -506,7 +506,7 @@ PyObject* ToPyObject(const paddle::framework::proto::VarType& type) { } PyObject* ToPyObject(const paddle::framework::LoDTensor* value) { - auto obj = ::pybind11::cast(value, py::return_value_policy::copy); + auto obj = ::pybind11::cast(value, py::return_value_policy::reference); obj.inc_ref(); return obj.ptr(); } diff --git a/paddle/pten/api/include/manual_api.h b/paddle/pten/api/include/manual_api.h index 3bd7e60154d06a6f5bf2381b8c62d14bc85c2212..942bbe970457211f7e74c95067c43b63a4059748 100644 --- a/paddle/pten/api/include/manual_api.h +++ b/paddle/pten/api/include/manual_api.h @@ -16,6 +16,8 @@ limitations under the License. */ #include "paddle/pten/api/include/tensor.h" #include "paddle/pten/common/backend.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" /** * This file stores some special APIs that are implemented manually @@ -28,5 +30,11 @@ namespace experimental { // TODO(chenweihang): Replace backend by place when place is ready PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking); +// TODO(chentianyu03): Split API has extra logic to calculate the outputs size, +// api_gen do not support +PADDLE_API std::vector split(const Tensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis); + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/manual_api.cc b/paddle/pten/api/lib/manual_api.cc index 1af5150b4aed475884fc16e01c67feaf020dce53..667bd177ee1f6232f44299502e348e23579cb49c 100644 --- a/paddle/pten/api/lib/manual_api.cc +++ b/paddle/pten/api/lib/manual_api.cc @@ -19,9 +19,12 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/pten/api/lib/api_registry.h" +#include "paddle/pten/api/lib/api_utils.h" +#include "paddle/pten/api/lib/data_transform.h" #include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/core/meta_tensor.h" #include "paddle/pten/infermeta/unary.h" PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); @@ -75,6 +78,71 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) { return out; } +PADDLE_API std::vector split(const Tensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis) { + Backend kernel_backend = Backend::UNDEFINED; + DataLayout kernel_layout = DataLayout::UNDEFINED; + DataType kernel_data_type = DataType::UNDEFINED; + + if (kernel_backend == Backend::UNDEFINED || + kernel_layout == DataLayout::UNDEFINED || + kernel_data_type == DataType::UNDEFINED) { + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + if (kernel_backend == Backend::UNDEFINED) { + kernel_backend = kernel_key.backend(); + } + if (kernel_layout == DataLayout::UNDEFINED) { + kernel_layout = kernel_key.layout(); + } + if (kernel_data_type == DataType::UNDEFINED) { + kernel_data_type = kernel_key.dtype(); + } + } + + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + "split", {kernel_backend, kernel_layout, kernel_data_type}); + VLOG(6) << "split API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "split API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + auto dense_x = PrepareData(x, kernel.InputAt(0), {}); + + // Calculate the number of out tensors + size_t out_number; + if (num_or_sections.GetData().size() == 1) { + out_number = num_or_sections.GetData()[0]; + } else { + out_number = num_or_sections.GetData().size(); + } + + std::vector out; + auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out); + std::vector meta_outs; + for (size_t i = 0; i < out_number; ++i) { + meta_outs.push_back(dense_outs[i]); + } + + pten::SplitInferMeta( + MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const pten::DenseTensor&, + const pten::ScalarArray&, + const pten::Scalar&, + std::vector&); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *dense_x, + pten::ScalarArray(num_or_sections), + pten::Scalar(axis), + dense_outs); + + return out; +} } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index c0d72452501c3d9cc659710002573655ab2458a8..230787c1b35cdc5bfe2d44ee5757b301404ea871 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -36,45 +36,6 @@ std::unique_ptr MakePtenDenseTensor( return std::make_unique(src); } -pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src) { - PADDLE_ENFORCE_EQ(src.numel(), - 1, - paddle::platform::errors::InvalidArgument( - "The Scalar only supports Tensor with 1 element, " - "but now Tensor has %d element.", - src.numel())); - switch (src.type()) { - case paddle::framework::proto::VarType::FP32: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::FP64: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::FP16: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::BF16: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT32: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT64: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT16: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT8: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::UINT8: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::BOOL: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::COMPLEX64: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::COMPLEX128: - return {src.template data()[0]}; - default: - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "Data type error. Don't support casting a %d LoDTensor to Scalar.", - src.type())); - } -} - pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU); if (variable.IsType()) { @@ -82,9 +43,9 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { if (!platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); - return MakePtenScalar(tmp_tensor); + return {tmp_tensor}; } else { - return MakePtenScalar(tensor); + return {tensor}; } } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -95,17 +56,7 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { } pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src) { - if (src.type() == paddle::framework::proto::VarType::INT64) { - return {src.data(), src.numel()}; - } else if (src.type() == paddle::framework::proto::VarType::INT32) { - return {src.data(), src.numel()}; - } else { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "Data type error. When cast a LoDTensor to ScalarArray, " - "the data type of LoDTensor must be int32 or int64, " - "but now data type is %s.", - src.type())); - } + return {src}; } pten::ScalarArray MakePtenScalarArrayFromVar( @@ -128,6 +79,7 @@ pten::ScalarArray MakePtenScalarArrayFromVar( } } +// TODO(chentianyu03): Inplace with ScalarArray constructor pten::ScalarArray MakePtenScalarArrayFromVarList( const std::vector& variable_list) { if (variable_list.size() == 0) { @@ -135,45 +87,28 @@ pten::ScalarArray MakePtenScalarArrayFromVarList( } auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU); - paddle::framework::proto::VarType::Type data_type; - auto* first_var = variable_list.front(); - if (first_var->IsType()) { - const auto& tensor = first_var->Get(); - data_type = tensor.type(); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport casting input `%s` type to VectorTensor when call pt " - "kernel.", - framework::ToTypeName(first_var->Type()))); - } - std::vector vector_data; vector_data.reserve(variable_list.size()); - if (data_type == paddle::framework::proto::VarType::INT64) { - for (auto* var : variable_list) { - if (var->IsType()) { + for (auto* var : variable_list) { + paddle::framework::proto::VarType::Type data_type; + if (var->IsType()) { + const auto& tensor = var->Get(); + data_type = tensor.type(); + if (data_type == paddle::framework::proto::VarType::INT64) { const auto& tensor = var->Get(); - if (!platform::is_same_place(tensor.place(), expected_place)) { + if (tensor.IsInitialized() && + !platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); vector_data.push_back(*tmp_tensor.data()); } else { vector_data.push_back(*tensor.data()); } - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport casting input `%s` type to VectorTensor when call pt " - "kernel.", - framework::ToTypeName(var->Type()))); - } - } - - } else if (data_type == paddle::framework::proto::VarType::INT32) { - for (auto* var : variable_list) { - if (var->IsType()) { + } else if (data_type == paddle::framework::proto::VarType::INT32) { const auto& tensor = var->Get(); - if (!platform::is_same_place(tensor.place(), expected_place)) { + if (tensor.IsInitialized() && + !platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); vector_data.push_back(*tmp_tensor.data()); @@ -181,21 +116,24 @@ pten::ScalarArray MakePtenScalarArrayFromVarList( vector_data.push_back(*tensor.data()); } } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport casting input `%s` type to VectorTensor when call pt " - "kernel.", - framework::ToTypeName(var->Type()))); + PADDLE_THROW(pten::errors::InvalidArgument( + "Data type error. When cast a LoDTensor to VectorTensor, " + "the data type of LoDTensor must be int32 or int64, " + "but now data type is %s.", + data_type)); } + } else { + PADDLE_THROW(pten::errors::Unimplemented( + "Unsupport casting input `%s` type to VectorTensor when call pt " + "kernel.", + framework::ToTypeName(var->Type()))); } - } else { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "Data type error. When cast a LoDTensor to VectorTensor, " - "the data type of LoDTensor must be int32 or int64, " - "but now data type is %s.", - data_type)); } - return {vector_data}; + pten::ScalarArray result{vector_data}; + result.setInitByTensor(true); + + return result; } void ResetTensorDtypeAndLayoutByArgDef(pten::TensorBase* dst, diff --git a/paddle/pten/api/lib/utils/tensor_utils.h b/paddle/pten/api/lib/utils/tensor_utils.h index 1e2d8b74db84941f970c0613fad4fa488f813053..cf1daf732ee96711fde3e5910899d972dd748ceb 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.h +++ b/paddle/pten/api/lib/utils/tensor_utils.h @@ -33,8 +33,6 @@ namespace experimental { std::unique_ptr MakePtenDenseTensor( const paddle::framework::Tensor& src); -pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src); - pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src); pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable); diff --git a/paddle/pten/common/scalar.h b/paddle/pten/common/scalar.h index 5c8fb04633088a0f9bc53877e1ab7bddf1f073ad..0ab880d6218f8778d479f166ab26db2c651ab6ac 100644 --- a/paddle/pten/common/scalar.h +++ b/paddle/pten/common/scalar.h @@ -25,6 +25,7 @@ namespace experimental { template class ScalarBase { public: + bool IsInitByTensor() const { return is_init_by_tensor_; } // Constructor support implicit ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT data_.f64 = val; @@ -103,6 +104,7 @@ class ScalarBase { // The Tensor must have one dim ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT + is_init_by_tensor_ = true; PD_CHECK( tensor.numel() == 1, "The Scalar only supports Tensor with 1 element, but now Tensor has `", @@ -194,6 +196,7 @@ class ScalarBase { friend void CopyScalar(const ScalarBase& src, ScalarBase* dst); private: + bool is_init_by_tensor_{false}; DataType dtype_; union data { bool b; diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h index 81013d8e5a11cdd6b44587bb2151b7be18895c27..dcc8ff6748b869ac25c882e66f1ff77bc94534a2 100644 --- a/paddle/pten/common/scalar_array.h +++ b/paddle/pten/common/scalar_array.h @@ -43,8 +43,13 @@ class ScalarArrayBase { AssignData(date_value, n); } + bool IsInitByTensor() const { return is_init_by_tensor_; } + + void setInitByTensor(bool val) { is_init_by_tensor_ = val; } + // The Tensor must have one dim ScalarArrayBase(const T& tensor) { // NOLINT + is_init_by_tensor_ = true; size_t n = tensor.numel(); array_.reserve(n); switch (tensor.dtype()) { @@ -66,41 +71,17 @@ class ScalarArrayBase { // The Tensor in vec must have only one element ScalarArrayBase(const std::vector& tensor_list) { // NOLINT - auto n = tensor_list.size(); - array_.reserve(n); - if (!tensor_list.empty()) { - DataType data_type = tensor_list[0].dtype(); + is_init_by_tensor_ = true; + + for (size_t i = 0; i < tensor_list.size(); ++i) { + DataType data_type = tensor_list[i].dtype(); switch (data_type) { - case DataType::INT32: { - for (size_t i = 0; i < n; ++i) { - PD_CHECK(tensor_list[i].dtype() == data_type, - "The data_type of tensors in the list isn't consistent." - "the first tensor is`", - data_type, - "` but `", - i, - "`th tensor is`", - tensor_list[i].dtype(), - "`."); - array_.push_back(*tensor_list[i].template data()); - } + case DataType::INT32: + array_.push_back(*tensor_list[i].template data()); break; - } - case DataType::INT64: { - for (size_t i = 0; i < n; ++i) { - PD_CHECK(tensor_list[i].dtype() == data_type, - "The data_type of tensors in the list isn't consistent." - "the first tensor is`", - data_type, - "` but `", - i, - "`th tensor is`", - tensor_list[i].dtype(), - "`."); - array_.push_back(*tensor_list[i].template data()); - } + case DataType::INT64: + array_.push_back(*tensor_list[i].template data()); break; - } default: PD_THROW( "Data type error. Currently, The data type of ScalarArrayBase " @@ -136,6 +117,7 @@ class ScalarArrayBase { // TODO(zhangyunfei) Replace std::vector with a more efficient container // structure. std::vector array_; + bool is_init_by_tensor_{false}; }; using ScalarArray = diff --git a/paddle/pten/core/compat/arg_map_context.h b/paddle/pten/core/compat/arg_map_context.h index 42ab0f1fcc2bf3a19c67bce4e0475c8ee2bb3966..c2c2b0a518d6c30b441a55ecaeb88811f9187ea8 100644 --- a/paddle/pten/core/compat/arg_map_context.h +++ b/paddle/pten/core/compat/arg_map_context.h @@ -77,6 +77,7 @@ class ArgumentMappingContext { virtual bool HasInput(const std::string& name) const = 0; virtual bool HasOutput(const std::string& name) const = 0; + virtual bool HasAttr(const std::string& name) const = 0; // now we can't use Attribute here, it will cause pten relay on // boost::variant and BlockDesc diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h index c95ae6b69f73b528c5f784d01fa2100f48971b46..6de91db9382e22537e577ce3188764034c7235e3 100644 --- a/paddle/pten/core/infermeta_utils.h +++ b/paddle/pten/core/infermeta_utils.h @@ -146,6 +146,7 @@ struct InferMetaFnImpl { } }; + // TODO(chenweihang): support other attr type later PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); diff --git a/paddle/pten/infermeta/backward.cc b/paddle/pten/infermeta/backward.cc index b7bb17bdd1c38b1d2616bf31c52ecfdbd8626b55..db92449519436024a01c9c891f9671756777a345 100644 --- a/paddle/pten/infermeta/backward.cc +++ b/paddle/pten/infermeta/backward.cc @@ -23,8 +23,12 @@ void MatmulGradInferMeta(const MetaTensor& x, bool transpose_y, MetaTensor* dx, MetaTensor* dy) { - dx->share_meta(x); - dy->share_meta(y); + if (dx) { + dx->share_meta(x); + } + if (dy) { + dy->share_meta(y); + } } } // namespace pten diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 5f3b0712b5863145f7340dfb6ae34a6809d9d635..ca59937399a226558c213fed5b43a2311a2f368a 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -315,4 +315,137 @@ void TransferLayoutInferMeta(const MetaTensor& x, out->set_layout(layout); } +void SplitInferMeta(const MetaTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis, + std::vector* out, + MetaConfig config) { + int axis_value = axis.to(); + int rank = x.dims().size(); + PADDLE_ENFORCE_EQ( + axis_value >= -rank && axis_value < rank, + true, + paddle::platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, + rank, + axis_value)); + if (axis_value < 0) { + axis_value = axis_value + rank; + } + + auto input_axis_dim = x.dims().at(axis_value); + auto num_or_sections_data = num_or_sections.GetData(); + // step1: get formated sections + std::vector sections; + // num_or_sections is a number + if (num_or_sections_data.size() == 1) { + int num = num_or_sections_data.at(0); + + PADDLE_ENFORCE_EQ(input_axis_dim % num, + 0, + paddle::platform::errors::InvalidArgument( + "The input's size along the split dimension " + "must be evenly divisible by Attr(num_or_sections). " + "But received Attr(num_or_sections) " + "= %d, input(X)'s shape = [%s], Attr(dim) = %d.", + num, + x.dims(), + axis_value)); + + for (int i = 0; i < num; ++i) { + sections.push_back(input_axis_dim / num); + } + } else { + // num_or_sections is a sections + const int unknow_dim_val = -1; + int unknow_dim_idx = -1; + int num_of_unknow = 0; + int sum_of_section = 0; + + for (size_t i = 0; i < num_or_sections_data.size(); ++i) { + sections.push_back(num_or_sections_data[i]); + + if (num_or_sections_data[i] == unknow_dim_val) { + num_of_unknow++; + unknow_dim_idx = i; + } else { + sum_of_section += num_or_sections_data[i]; + } + } + + if (config.is_runtime) { + PADDLE_ENFORCE_LE(num_of_unknow, + 1, + paddle::platform::errors::InvalidArgument( + "Only one dimension value of Attr(num_or_sections) " + "in SplitOp can be -1. " + "But received Attr(num_or_sections) = [%s].", + pten::framework::make_ddim(num_or_sections_data))); + } + + if (unknow_dim_idx != -1) { + // for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1]. + // input_axis_dim = 5, sum_of_sections = 5. + // the following check will fail. + PADDLE_ENFORCE_LT( + sum_of_section, + input_axis_dim, + paddle::platform::errors::InvalidArgument( + "Sum of Attr(num_or_sections) other than unknown section " + "must be less than the input's " + "size " + "along the split dimension. But received Attr(num_or_sections) " + "= [%s], input(X)'s shape = [%s], Attr(dim) = %d.", + pten::framework::make_ddim(num_or_sections_data), + x.dims(), + axis_value)); + + if (config.is_runtime) { + sections[unknow_dim_idx] = input_axis_dim - sum_of_section; + } + } else { + PADDLE_ENFORCE_EQ( + sum_of_section, + input_axis_dim, + paddle::platform::errors::InvalidArgument( + "Sum of Attr(num_or_sections) must be equal to the input's " + "size " + "along the split dimension. But received Attr(num_or_sections)" + " = [%s], input(X)'s shape = [%s], Attr(dim) = %d.", + pten::framework::make_ddim(num_or_sections_data), + x.dims(), + axis_value)); + } + } + + // setp2: fill out dims + std::vector out_dims(sections.size(), x.dims()); + if (config.is_runtime || input_axis_dim > 0) { + for (size_t i = 0; i < sections.size(); ++i) { + out_dims[i][axis_value] = sections[i]; + } + } else { + for (size_t i = 0; i < sections.size(); ++i) { + out_dims[i][axis_value] = -1; + } + } + + for (size_t i = 0; i < sections.size(); ++i) { + if (axis_value != 0) { + // Only pass LoD when not spliting along the first dim. + (*out)[i].set_dtype(x.dtype()); + (*out)[i].set_dims(out_dims[i]); + (*out)[i].set_layout(x.layout()); + } else { + (*out)[i].set_dtype(x.dtype()); + (*out)[i].set_dims(out_dims[i]); + (*out)[i].set_layout(x.layout()); + (*out)[i].share_lod(x); + } + } + + return; +} + } // namespace pten diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index f1dc806b4e9caee980f0bd4b9d5085f375c55bd2..4c816c4adbc233e0442c2100f62ee8e62cc8f78c 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once // See Note [ Why still include the fluid headers? ] +#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/meta_tensor.h" @@ -74,4 +75,9 @@ void TransferLayoutInferMeta(const MetaTensor& x, DataLayout layout, MetaTensor* out); +void SplitInferMeta(const MetaTensor& x_meta, + const ScalarArray& num_or_sections, + const Scalar& axis, + std::vector* out, + MetaConfig config = MetaConfig()); } // namespace pten diff --git a/paddle/pten/kernels/cpu/split_kernel.cc b/paddle/pten/kernels/cpu/split_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..78fcdcb155cf23b146c1a44ccc9651b6506d1d4d --- /dev/null +++ b/paddle/pten/kernels/cpu/split_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/split_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/pten/common/float16.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/cpu/concat_and_split.h" +namespace pten { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis_scalar, + std::vector outs) { + // need to infershape output + if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { + std::vector out_metas; + for (size_t i = 0; i < outs.size(); ++i) { + out_metas.push_back(outs[i]); + } + + pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true); + + for (size_t i = 0; i < out_metas.size(); ++i) { + outs[i]->Resize(out_metas[i].dims()); + } + } + + std::vector shape_refer; + for (size_t j = 0; j < outs.size(); ++j) { + dev_ctx.Alloc(outs[j]); + shape_refer.emplace_back(outs[j]); + } + + int axis = axis_scalar.to(); + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + paddle::operators::StridedMemcpyWithAxis0( + dev_ctx, x, shape_refer, &outs); + } else { + SplitImpl(dev_ctx, x, shape_refer, axis, &outs); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(split, + CPU, + ALL_LAYOUT, + pten::SplitKernel, + float, + double, + int64_t, + int, + bool, + pten::dtype::float16) {} diff --git a/paddle/pten/kernels/gpu/concat_and_split.h b/paddle/pten/kernels/gpu/concat_and_split.h index 47022666564df9ba5626f00ef15feccfd3e900d1..17b54bbbfdc549adfd06194e7851506341638879 100644 --- a/paddle/pten/kernels/gpu/concat_and_split.h +++ b/paddle/pten/kernels/gpu/concat_and_split.h @@ -134,12 +134,12 @@ __global__ void ConcatKernel_(const T** inputs_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t* out_cols, - int out_cols_size, - T** outputs_data) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t* out_cols, + int out_cols_size, + T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int curr_segment = 0; int curr_offset = out_cols[0]; @@ -184,21 +184,21 @@ __device__ void SplitKernelDetail(const T* input_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T** outputs_data) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T** outputs_data) { SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T* outputs_addr0, - T* outputs_addr1) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1) { T* outputs_data[2]; outputs_data[0] = outputs_addr0; outputs_data[1] = outputs_addr1; @@ -206,13 +206,13 @@ __global__ void SplitKernel(const T* input_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T* outputs_addr0, - T* outputs_addr1, - T* outputs_addr2) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1, + T* outputs_addr2) { T* outputs_data[3]; outputs_data[0] = outputs_addr0; outputs_data[1] = outputs_addr1; @@ -221,14 +221,14 @@ __global__ void SplitKernel(const T* input_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T* outputs_addr0, - T* outputs_addr1, - T* outputs_addr2, - T* outputs_addr3) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1, + T* outputs_addr2, + T* outputs_addr3) { T* outputs_data[4]; outputs_data[0] = outputs_addr0; outputs_data[1] = outputs_addr1; @@ -497,7 +497,7 @@ void SplitImpl(const Context& context, if (has_same_shape) { if (o_num == 2) { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, @@ -505,7 +505,7 @@ void SplitImpl(const Context& context, outputs_data[0], outputs_data[1]); } else if (o_num == 3) { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, @@ -514,7 +514,7 @@ void SplitImpl(const Context& context, outputs_data[1], outputs_data[2]); } else if (o_num == 4) { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, @@ -524,7 +524,7 @@ void SplitImpl(const Context& context, outputs_data[2], outputs_data[3]); } else { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, out0_col, dev_out_gpu_data); } } else { @@ -542,7 +542,7 @@ void SplitImpl(const Context& context, int64_t* dev_outs_col_data = reinterpret_cast(tmp_dev_ins_col_data->ptr()); - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, diff --git a/paddle/pten/kernels/gpu/split_kernel.cu b/paddle/pten/kernels/gpu/split_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..46d18b75b611b7acad8155b0d6bbed7b015a23c9 --- /dev/null +++ b/paddle/pten/kernels/gpu/split_kernel.cu @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/split_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/pten/common/float16.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/gpu/concat_and_split.h" +namespace pten { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis_scalar, + std::vector outs) { + // need to infershape output + if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { + std::vector out_metas; + for (size_t i = 0; i < outs.size(); ++i) { + out_metas.push_back(outs[i]); + } + + pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true); + + for (size_t i = 0; i < out_metas.size(); ++i) { + outs[i]->Resize(out_metas[i].dims()); + } + } + + std::vector shape_refer; + for (size_t j = 0; j < outs.size(); ++j) { + dev_ctx.Alloc(outs[j]); + shape_refer.emplace_back(outs[j]); + } + + int axis = axis_scalar.to(); + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + paddle::operators::StridedMemcpyWithAxis0( + dev_ctx, x, shape_refer, &outs); + } else { + SplitImpl(dev_ctx, x, shape_refer, axis, &outs); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(split, + GPU, + ALL_LAYOUT, + pten::SplitKernel, + float, + double, + int64_t, + int, + bool, + pten::dtype::float16, + pten::dtype::bfloat16) {} diff --git a/paddle/pten/kernels/split_kernel.h b/paddle/pten/kernels/split_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..30ac4da7a4ca04f519c58bb9ac205388a4448f5a --- /dev/null +++ b/paddle/pten/kernels/split_kernel.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/empty_kernel.h" + +namespace pten { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis, + std::vector out); + +template +std::vector Split(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis) { + size_t out_number; + if (num_or_sections.GetData().size() == 1) { + out_number = num_or_sections.GetData()[0]; + } else { + out_number = num_or_sections.GetData().size(); + } + + std::vector out_meta; + out_meta.reserve(out_number); + std::vector result; + result.reserve(out_number); + + for (size_t i = 0; i < out_number; ++i) { + auto dense_out = pten::Empty(dev_ctx); + MetaTensor tmp_meta(&dense_out); + + result.push_back(dense_out); + out_meta.push_back(&result.back()); + } + SplitInferMeta(x, num_or_sections, axis, &out_meta); + + std::vector outs; + outs.reserve(out_meta.size()); + for (size_t i = 0; i < out_meta.size(); ++i) { + outs.push_back(&result[i]); + } + + SplitKernel(dev_ctx, x, num_or_sections, axis, outs); + + return result; +} + +} // namespace pten diff --git a/paddle/pten/ops/compat/matmul_sig.cc b/paddle/pten/ops/compat/matmul_sig.cc index 963d5d6656b04aa94181bdc07ab7b0cf4d92de57..7f1f2cf437a4654881b0d28bf968e5fa5cab9783 100644 --- a/paddle/pten/ops/compat/matmul_sig.cc +++ b/paddle/pten/ops/compat/matmul_sig.cc @@ -17,10 +17,17 @@ limitations under the License. */ namespace pten { KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("matmul_grad", - {"X", "Y", GradVarName("Out")}, - {"trans_x", "trans_y"}, - {GradVarName("X"), GradVarName("Y")}); + if (ctx.HasAttr("use_addto")) { + return KernelSignature("addto_matmul_grad", + {"X", "Y", GradVarName("Out")}, + {"trans_x", "trans_y", "use_addto"}, + {GradVarName("X"), GradVarName("Y")}); + } else { + return KernelSignature("matmul_grad", + {"X", "Y", GradVarName("Out")}, + {"trans_x", "trans_y"}, + {GradVarName("X"), GradVarName("Y")}); + } } KernelSignature MatmulDoubleGradOpArgumentMapping( diff --git a/paddle/pten/ops/compat/split_sig.cc b/paddle/pten/ops/compat/split_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec58af5e9e41d3f89e9422b9b0f5bd650ee8347d --- /dev/null +++ b/paddle/pten/ops/compat/split_sig.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature SplitOpArgumentMapping(const ArgumentMappingContext& ctx) { + // priority: num > SectionsTensorList > sections + // priority: AxisTensor > axis + if (paddle::any_cast(ctx.Attr("num")) > 0) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("split", {"X"}, {"num", "AxisTensor"}, {"Out"}); + } else { + return KernelSignature("split", {"X"}, {"num", "axis"}, {"Out"}); + } + } + + if (ctx.InputSize("SectionsTensorList") > 0) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature( + "split", {"X"}, {"SectionsTensorList", "AxisTensor"}, {"Out"}); + } else { + return KernelSignature( + "split", {"X"}, {"SectionsTensorList", "axis"}, {"Out"}); + } + } + + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("split", {"X"}, {"sections", "AxisTensor"}, {"Out"}); + } else { + return KernelSignature("split", {"X"}, {"sections", "axis"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(split, pten::SplitOpArgumentMapping); diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index b8491ab7f5ea89c0aaaf72ae9ae55ba7ea435083..d875dbd4444ae664472663caa0ea5b2694ca8e4f 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -22,6 +22,6 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils) - +cc_test(test_split_api SRCS test_split_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_data_transform SRCS test_data_transform.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_split_api.cc b/paddle/pten/tests/api/test_split_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac139832aa0082ae23d1ebee05d68cb360690241 --- /dev/null +++ b/paddle/pten/tests/api/test_split_api.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/pten/api/include/api.h" + +#include "paddle/pten/api/include/manual_api.h" +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace paddle { +namespace tests { + +namespace framework = paddle::framework; +using DDim = pten::framework::DDim; + +// TODO(chentianyu03): Remove this test after the API is used in the dygraph +TEST(API, split) { + // 1. create tensor + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + pten::framework::make_ddim({4, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = + dense_x->mutable_data(paddle::platform::CPUPlace()); + + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + } + } + + paddle::experimental::Tensor x(dense_x); + + // 2. test API + auto out = paddle::experimental::split(x, {2, 2}, 0); + + // 3. check result + ASSERT_EQ(out.size(), static_cast(2)); + ASSERT_EQ(out[0].dims().size(), 2); + ASSERT_EQ(out[0].dims()[0], 2); + ASSERT_EQ(out[0].dims()[1], 10); + ASSERT_EQ(out[0].type(), pten::DataType::FLOAT32); + ASSERT_EQ(out[0].layout(), pten::DataLayout::NCHW); + + ASSERT_EQ(out[1].dims().size(), 2); + ASSERT_EQ(out[1].dims()[0], 2); + ASSERT_EQ(out[1].dims()[1], 10); + ASSERT_EQ(out[1].type(), pten::DataType::FLOAT32); + ASSERT_EQ(out[1].layout(), pten::DataLayout::NCHW); + + auto out_data_0 = std::dynamic_pointer_cast(out[0].impl()) + ->data(); + auto out_data_1 = std::dynamic_pointer_cast(out[1].impl()) + ->data(); + for (size_t i = 0; i < 4; ++i) { + if (i < 20) { + ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6); + } else { + ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6); + } + } +} + +} // namespace tests +} // namespace paddle diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index e2063241689f929e6d173bcb29dde849ca5a3f48..15a1cab5f0dd473498ebb23e564ce88400af9713 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -11,4 +11,5 @@ cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_uti cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils) cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils) cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS pten pten_api_utils) cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_split_dev_api.cc b/paddle/pten/tests/kernels/test_split_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4e3619e11a3ad976938b3c78b303cbdcd04ce1b --- /dev/null +++ b/paddle/pten/tests/kernels/test_split_dev_api.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/kernels/split_kernel.h" + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/pten/api/include/manual_api.h" +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +namespace pten { +namespace tests { + +namespace framework = paddle::framework; +using DDim = pten::framework::DDim; + +TEST(DEV_API, split) { + // 1. create tensor + const auto alloc = std::make_unique( + pten::CPUPlace()); + pten::DenseTensor dense_x( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + pten::framework::make_ddim({4, 10}), + pten::DataLayout::NCHW)); + pten::CPUContext dev_ctx; + dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx.Init(); + + auto* dense_x_data = dev_ctx.Alloc(&dense_x); + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + } + } + + // 2. test API + auto out = pten::Split(dev_ctx, dense_x, {2, 2}, 0); + + // 3. check result + ASSERT_EQ(out.size(), static_cast(2)); + ASSERT_EQ(out[0].dims().size(), 2); + ASSERT_EQ(out[0].dims()[0], 2); + ASSERT_EQ(out[0].dims()[1], 10); + ASSERT_EQ(out[0].meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out[0].meta().layout, pten::DataLayout::NCHW); + + ASSERT_EQ(out[1].dims().size(), 2); + ASSERT_EQ(out[1].dims()[0], 2); + ASSERT_EQ(out[1].dims()[1], 10); + ASSERT_EQ(out[1].meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out[1].meta().layout, pten::DataLayout::NCHW); + + auto out_data_0 = out[0].data(); + auto out_data_1 = out[1].data(); + for (size_t i = 0; i < 4; ++i) { + if (i < 20) { + ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6); + } else { + ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6); + } + } +} + +} // namespace tests +} // namespace pten diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index af78b6d21eb8cf89b8964d4835ba815ec783c71b..7f2ad893f67a34f2bd4772614ff8f94f33f3bdb8 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -1759,11 +1759,11 @@ set +x set -x ut_endTime_s=`date +%s` echo "XPU testCase Time: $[ $ut_endTime_s - $ut_startTime_s ]s" + python ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py + unset XPU_OP_LIST_DIR if [[ "$EXIT_CODE" != "0" ]]; then exit 8; fi - python ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py - unset XPU_OP_LIST_DIR fi } diff --git a/python/paddle/distributed/passes/ps_trainer_pass.py b/python/paddle/distributed/passes/ps_trainer_pass.py index 28cfa873b96da0370b7814a926163e2ce0c1c7b0..fff10a2d4684afe51295cc460f8dc3424d13c4f5 100755 --- a/python/paddle/distributed/passes/ps_trainer_pass.py +++ b/python/paddle/distributed/passes/ps_trainer_pass.py @@ -21,25 +21,6 @@ from .pass_base import PassBase, register_pass from paddle.fluid.transpiler.details.program_utils import delete_ops from paddle.fluid.transpiler.collective import SingleProcessMultiThread -OP_NAME_SCOPE = "op_namescope" -CLIP_OP_NAME_SCOPE = "gradient_clip" -STEP_COUNTER = "@PS_STEP_COUNTER@" -OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() -RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC -LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched -OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize -op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() -backward = core.op_proto_and_checker_maker.OpRole.Backward - -SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"} -SPARSE_GRAD_OP_TYPE_DICT = { - "lookup_table_grad": "W", - "lookup_table_v2_grad": "W" -} -DEVICE_LIST = ["cpu", "gpu", "xpu"] -COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"] -DEFAULT_DEVICE = 'cpu' - @register_pass("append_send_ops_pass") class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用 @@ -894,6 +875,100 @@ class SplitTrainerOpsPass(PassBase): def _check_conflict(self, other_pass): return True + def _replace_ops_by_communicate_op(self, program, attrs, heter_block_index, + ops_list, block_var_detail): + all_op = program.global_block().ops + start_op = ops_list[0] + first_op_idx = -1 + for op in all_op: + if str(op) == str(start_op): + first_op_idx = all_op.index(op) + break + assert first_op_idx != -1 + self._delete_same_ops(program.global_block(), ops_list) + + entrance_var = [] + role_maker = attrs['role_maker'] + if heter_block_index == 1: + next_heter_worker_endpoints = get_next_stage_trainers(role_maker) + + entrance_var = block_var_detail[heter_block_index]["forward"][ + "entrance"] + + comm_info = get_communicate_var_info(program, heter_block_index + 1, + entrance_var) + program.global_block()._insert_op( + index=first_op_idx, + type="send_and_recv", + inputs={"X": program.global_block().vars[entrance_var[0]]}, + outputs={"Out": []}, + attrs={ + "mode": "forward", + "send_var_name": entrance_var + ["microbatch_id"], + "recv_var_name": [], + "message_name": comm_info["block_input_var_name"], + "next_endpoints": next_heter_worker_endpoints, + "previous_endpoints": [], + "trainer_id": get_role_id(role_maker), + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + return entrance_var + + def _delete_same_ops(self, block, ops): + for op in ops: + try: + for origin_op in block.ops: + if str(origin_op) == str(op): + idx = list(block.ops).index(origin_op) + block._remove_op(idx) + break + except Exception as e: + print(e) + + def _remove_var_pair_by_grad(self, var_name, attrs): + for index, pair in enumerate(attrs['merged_variables_pairs']): + var = pair[0] + var_grad = pair[1] + if var_grad.merged_var.name == var_name: + del attrs['merged_variables_pairs'][index] + + for index, pair in enumerate(attrs['merged_dense_pairs']): + var = pair[0] + var_grad = pair[1] + if var_grad.merged_var.name == var_name: + del attrs['merged_dense_pairs'][index] + return + + for index, pair in enumerate(attrs['merged_sparse_pairs']): + var = pair[0] + var_grad = pair[1] + if var_grad.merged_var.name == var_name: + del attrs['merged_sparse_pairs'][index] + return + + def _remove_trainer_send_op(self, program, attrs, heter_block_index, + block_var_detail): + # if trainer do FF->BP->SEND, it has follow vars: var, var@GRAD + # if trainer only do SEND, it has one var: var@GRAD + # Delete Send op ,if trainer doesn't has pair var (var<->var@GRAD) + persistables = block_var_detail[heter_block_index]["forward"]["persistables"] + \ + block_var_detail[heter_block_index]["backward"]["persistables"] + need_remove_send_op = [] + need_remove_grad_var = [] + for op in find_send_op(program): + input_list, _ = find_op_input_output(program, + program.global_block(), op) + for var_name in input_list: + origin_var_name = var_name.split("@GRAD")[0] + if origin_var_name in persistables: + need_remove_send_op.append(op) + need_remove_grad_var.append(var_name) + need_remove_send_op = list(set(need_remove_send_op)) + delete_ops(program.global_block(), need_remove_send_op) + for grad_var_name in need_remove_grad_var: + self._remove_var_pair_by_grad(grad_var_name, attrs) + def _create_trainer_program(self, program, origin_program, attrs, program_block_ops_list, block_var_detail): # This function mainly includes the following contents: @@ -911,18 +986,18 @@ class SplitTrainerOpsPass(PassBase): ops_list = program_block_ops_list[heter_block_index][ "forward"] + program_block_ops_list[heter_block_index][ "backward"] - static_var += replace_ops_by_communicate_op( + static_var += self._replace_ops_by_communicate_op( program, attrs, heter_block_index, ops_list, block_var_detail) - remove_trainer_send_op(program, attrs, heter_block_index, - block_var_detail) + self._remove_trainer_send_op(program, attrs, heter_block_index, + block_var_detail) optimizer_block = [] grad_to_block_id = [] bp_ops_list = program_block_ops_list[0]["backward"] - delete_same_ops(program.global_block(), bp_ops_list) - delete_trainer_useless_var(attrs, program, static_var) - backward_block = create_backward_block(program, origin_program, attrs, + self._delete_same_ops(program.global_block(), bp_ops_list) + delete_trainer_useless_var(program, static_var) + backward_block = create_backward_block(program, origin_program, bp_ops_list, block_var_detail) bp_entrance_vars = block_var_detail[0]["backward"]["entrance"] diff --git a/python/paddle/distributed/ps/utils/ps_program_builder.py b/python/paddle/distributed/ps/utils/ps_program_builder.py index d978adaaba05b637f150d7499aa8fe7f48128715..d649a74e4d621bbc531ce194242fbbd07b01209a 100755 --- a/python/paddle/distributed/ps/utils/ps_program_builder.py +++ b/python/paddle/distributed/ps/utils/ps_program_builder.py @@ -186,10 +186,10 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): add_lr_decay_table_pass.apply([], [], self.pass_ctx) distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs) - distributed_ops_pass.apply([self.cloned_main], [], self.pass_ctx) + distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx) delete_optimizer_pass = new_pass("delete_optimizer_pass", self.attrs) - delete_optimizer_pass.apply([None], [_startup], self.pass_ctx) + delete_optimizer_pass.apply([self.cloned_main], [None], self.pass_ctx) append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs) append_send_ops_pass.apply([self.cloned_main], [None], self.pass_ctx) @@ -210,12 +210,13 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): else: split_trainer_ops_pass = new_pass("split_trainer_ops_pass", self.attrs) - split_trainer_ops_pass([self.cloned_main], [], self.pass_ctx) + split_trainer_ops_pass.apply([self.cloned_main], [None], + self.pass_ctx) set_heter_pipeline_opt_pass = new_pass('set_heter_pipeline_opt_pass', self.attrs) set_heter_pipeline_opt_pass.apply([self.cloned_main], - [self.cloned_startup], pass_ctx) + [self.cloned_startup], self.pass_ctx) if self.launch_barrier and self.launch_barrier_flag: wait_server_ready(server_endpoints) @@ -228,7 +229,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): ps_set_heter_pipeline_opt_pass = new_pass( "set_heter_pipeline_opt_pass", self.attrs) ps_set_heter_pipeline_opt_pass.apply( - [self.loss.block.program], [startup_program], self.pass_ctx) + [self.cloned_main], [self.cloned_startup], self.pass_ctx) elif self.attrs['is_server']: self._build_pserver_programs() diff --git a/python/paddle/distributed/ps/utils/public.py b/python/paddle/distributed/ps/utils/public.py index 3c883a0158adafed7753c9dd72a74fd438592ac4..a8587874776bb5f5586dd23ed32c1ee810ad97c0 100755 --- a/python/paddle/distributed/ps/utils/public.py +++ b/python/paddle/distributed/ps/utils/public.py @@ -42,9 +42,17 @@ RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize +backward = core.op_proto_and_checker_maker.OpRole.Backward +DEVICE_LIST = ["cpu", "gpu", "xpu"] +COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"] SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"] SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"} +SPARSE_GRAD_OP_TYPE_DICT = { + "lookup_table_grad": "W", + "lookup_table_v2_grad": "W" +} +DEFAULT_DEVICE = 'cpu' def logger_config(log_path, logging_name): @@ -640,6 +648,20 @@ def find_block_joints(program, program_block_ops_list, heter_ops): return block_var_detail +def find_ops_list_input_output(program, ops_list): + input_var_list = [] + output_var_list = [] + for op in ops_list: + inputs = _get_input_map_from_op(program.global_block().vars, op) + input_var_list += get_varlist_from_op_map(inputs) + outputs = _get_output_map_from_op(program.global_block().vars, op) + output_var_list += get_varlist_from_op_map(outputs) + + input_var_list = list(set(input_var_list)) + output_var_list = list(set(output_var_list)) + return input_var_list, output_var_list + + def find_entrance_exit_private(program, program_block_ops_list): block_var_detail = [] persistables = [] @@ -850,6 +872,54 @@ def _get_output_map_from_op(varmap, op): return iomap +def get_varlist_from_op_map(var_map): + var_list = [] + for key, varlist in six.iteritems(var_map): + if not isinstance(varlist, list): + varlist = [varlist] + for i in range(len(varlist)): + var = varlist[i] + var_list.append(var.name) + return var_list + + +def _get_input_map_from_op(varmap, op): + """Returns a dict from op input name to the vars in varmap.""" + iomap = collections.OrderedDict() + for key in op.input_names: + vars = [] + for varname in op.input(key): + if varname == "@EMPTY@": + continue + if "lod_tensor_blocking_queue" in varname: + continue + vars.append(varmap[varname]) + if len(vars) == 1: + iomap[key] = vars[0] + else: + iomap[key] = vars + return iomap + + +def screen_persistables(program, var_list): + need_remove = [] + for var_name in var_list: + if "@GRAD" in var_name: + if "GRAD" != var_name.split("@")[-1]: + continue + origin_var_name = var_name.split("@GRAD")[0] + var = program.global_block().vars[origin_var_name] + else: + var = program.global_block().vars[var_name] + + if fluid.io.is_persistable(var): + need_remove.append(var_name) + + for var_name in need_remove: + var_list.remove(var_name) + return need_remove + + def block_append_op(program, origin_program, block, op): merge_ordereddict = origin_program.global_block().vars.copy() merge_ordereddict.update(block.vars) @@ -1154,6 +1224,84 @@ def get_param_grads(origin_program): return sparse_param_grads, dense_param_grads +def delete_ops(block, ops): + for op in ops: + try: + idx = list(block.ops).index(op) + block._remove_op(idx) + except Exception as e: + print(e) + + +def find_send_op(program): + send_op_list = [] + for op in program.global_block().ops: + if op.type == "send": + send_op_list.append(op) + return send_op_list + + +def find_op_input_output(program, block, op): + input_var_list = [] + output_var_list = [] + inputs = _get_input_map_from_op(block.vars, op) + input_var_list += get_varlist_from_op_map(inputs) + outputs = _get_output_map_from_op(block.vars, op) + output_var_list += get_varlist_from_op_map(outputs) + input_var_list = list(set(input_var_list)) + output_var_list = list(set(output_var_list)) + return input_var_list, output_var_list + + +def get_vars_name_in_block(block): + vars_list = block.vars.keys() + vars_name_list = [var_name for var_name in vars_list] + return vars_name_list + + +def delete_trainer_useless_var(program, static_var): + static_var = list(set(static_var)) + program_useful_var_list = [] + for op in program.global_block().ops: + input_var_list, output_var_list = find_op_input_output( + program, program.global_block(), op) + op_var_list = list(set(input_var_list).union(set(output_var_list))) + program_useful_var_list = list( + set(program_useful_var_list).union(set(op_var_list))) + program_useful_var_list += static_var + program_useless_var_list = list( + set(get_vars_name_in_block(program.global_block())).difference( + set(program_useful_var_list))) + for var in program_useless_var_list: + program.global_block()._remove_var(var) + return program_useless_var_list + + +def create_backward_block(program, origin_program, bp_ops_list, + block_var_detail): + pre_block_idx = program.num_blocks - 1 + heter_block = program._create_block(pre_block_idx) + + for _, op in enumerate(bp_ops_list): + if op.type == "send": + send_varnames = op.attr('send_varnames') + is_skip = False + for varname in send_varnames: + if varname not in program.global_block( + ).vars and varname not in heter_block.vars: + is_skip = True + break + if is_skip == True: + continue + block_append_op(program, origin_program, heter_block, op) + + entrance_vars = block_var_detail[0]["backward"]["entrance"] + add_vars_by_var_list(entrance_vars, origin_program, program, heter_block) + exit_vars = block_var_detail[0]["backward"]["exit"] + add_vars_by_var_list(exit_vars, origin_program, program, heter_block) + return heter_block + + def debug_program(file, program, is_trainer): if is_trainer: with open(file, 'w+') as f: diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index a612a4013713ee660b8ccd141f59894992c05178..e0c594b07aeb519dcb3906cdfc03d9af92117059 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -21,6 +21,17 @@ from paddle.fluid import core from paddle.fluid import framework from paddle import _C_ops +final_state_name_mapping = { + "matmul_v2": { + "final_op_name": "final_state_matmul", + "transpose_x": "trans_x", + "transpose_y": "trans_y", + "x": "X", + "y": "Y", + "out": "Out", + } +} + class Tracer(core.Tracer): """ @@ -40,6 +51,169 @@ class Tracer(core.Tracer): self._train_mode = True + def eager_trace_op(self, + type, + inputs, + outputs, + attrs, + stop_gradient=False, + inplace_map=None): + function_ptr = _C_ops.__dict__[type] + + core_ops_args_info = _C_ops.get_core_ops_args_info() + core_ops_args_type_info = _C_ops.get_core_ops_args_type_info() + core_ops_returns_info = _C_ops.get_core_ops_returns_info() + + op_args = core_ops_args_info[type] + op_args_type = core_ops_args_type_info[type] + op_returns = core_ops_returns_info[type] + + arg_list = [] + for i in range(len(op_args)): + arg_name = op_args[i] + arg_type = op_args_type[i] + if arg_name in inputs.keys(): + arg_to_append = inputs[arg_name] + elif arg_name in outputs.keys(): + arg_to_append = outputs[arg_name] + else: + if "Num" in arg_name: + # Remove "Num" suffix to get out_name + out_name = arg_name[:-3] + assert out_name in outputs.keys() + num_outs = len(outputs[out_name]) + arg_to_append = num_outs + else: + arg_to_append = None + + if arg_to_append is None: + arg_list.append(arg_to_append) + elif arg_type == "tensor": + if isinstance(arg_to_append, list): + arg_list.append(arg_to_append[0]) + else: + arg_list.append(arg_to_append) + elif arg_type == "list": + assert isinstance(arg_to_append, list) + arg_list.append(arg_to_append) + else: + assert arg_type == "int" + assert isinstance(arg_to_append, int) + arg_list.append(arg_to_append) + + attrs_list = [] + for k, v in attrs.items(): + attrs_list.append(k) + attrs_list.append(v) + returns = function_ptr(*arg_list, *attrs_list) + + if isinstance(returns, tuple): + for i in range(len(op_returns)): + retname = op_returns[i] + if retname in outputs.keys(): + # Replaced outputs by function returns + if isinstance(returns[i], list): + for j in range(len(returns[i])): + outputs[retname][j].reconstruct_from_(returns[i][j], + False) + else: + outputs[retname][0].reconstruct_from_(returns[i], False) + elif isinstance(returns, list): + assert len(outputs.keys()) == 1 + key = list(outputs.keys())[0] + for j in range(len(returns)): + outputs[key][j].reconstruct_from_(returns[j], False) + else: + assert len(outputs.keys()) == 1 + key = list(outputs.keys())[0] + if isinstance(outputs[key], list): + outputs[key][0].reconstruct_from_(returns, False) + else: + outputs[key].reconstruct_from_(returns, False) + + def eager_final_state_trace_op(self, + type, + inputs, + outputs, + attrs, + stop_gradient=False, + inplace_map=None): + assert type in final_state_name_mapping.keys() + + final_state_type = final_state_name_mapping[type]["final_op_name"] + function_ptr = _C_ops.__dict__[final_state_type] + + core_ops_args_info = _C_ops.get_final_state_core_ops_args_info() + core_ops_args_type_info = _C_ops.get_final_state_core_ops_args_type_info( + ) + core_ops_returns_info = _C_ops.get_final_state_core_ops_returns_info() + + op_args = core_ops_args_info[final_state_type] + op_args_type = core_ops_args_type_info[final_state_type] + op_returns = core_ops_returns_info[final_state_type] + + arg_list = [] + for i in range(len(op_args)): + eager_arg_name = op_args[i] + arg_type = op_args_type[i] + + assert eager_arg_name in final_state_name_mapping[type].keys() + arg_name = final_state_name_mapping[type][eager_arg_name] + + if arg_name in inputs.keys(): + arg_to_append = inputs[arg_name] + elif arg_name in outputs.keys(): + arg_to_append = outputs[arg_name] + elif arg_name in attrs.keys() and arg_type == "": + arg_to_append = attrs[arg_name] + else: + # dispensable + arg_to_append = None + + if arg_type == "": + # attribute + arg_list.append(arg_to_append) + elif arg_type == "tensor": + if isinstance(arg_to_append, list): + arg_list.append(arg_to_append[0]) + else: + arg_list.append(arg_to_append) + elif arg_type == "list": + assert isinstance(arg_to_append, list) + arg_list.append(arg_to_append) + else: + assert arg_to_append is None + arg_list.append(arg_to_append) + + returns = function_ptr(*arg_list) + + if isinstance(returns, tuple): + for i in range(len(op_returns)): + eager_retname = op_returns[i] + + assert eager_retname in final_state_name_mapping[type].keys() + retname = final_state_name_mapping[type][eager_retname] + if retname in outputs.keys(): + # Replaced outputs by function returns + if isinstance(returns[i], list): + for j in range(len(returns[i])): + outputs[retname][j].reconstruct_from_(returns[i][j], + False) + else: + outputs[retname][0].reconstruct_from_(returns[i], False) + elif isinstance(returns, list): + assert len(outputs.keys()) == 1 + key = list(outputs.keys())[0] + for j in range(len(returns)): + outputs[key][j].reconstruct_from_(returns[j], False) + else: + assert len(outputs.keys()) == 1 + key = list(outputs.keys())[0] + if isinstance(outputs[key], list): + outputs[key][0].reconstruct_from_(returns, False) + else: + outputs[key].reconstruct_from_(returns, False) + def trace_op(self, type, inputs, @@ -51,78 +225,16 @@ class Tracer(core.Tracer): # inputs : {"sum": [tensor], ...} # outputs : {"sum": [tensor], ...} - function_ptr = _C_ops.__dict__[type] - - core_ops_args_info = _C_ops.get_core_ops_args_info() - core_ops_args_type_info = _C_ops.get_core_ops_args_type_info() - core_ops_returns_info = _C_ops.get_core_ops_returns_info() - - op_args = core_ops_args_info[type] - op_args_type = core_ops_args_type_info[type] - op_returns = core_ops_returns_info[type] - - arg_list = [] - for i in range(len(op_args)): - arg_name = op_args[i] - arg_type = op_args_type[i] - if arg_name in inputs.keys(): - arg_to_append = inputs[arg_name] - elif arg_name in outputs.keys(): - arg_to_append = outputs[arg_name] - else: - if "Num" in arg_name: - # Remove "Num" suffix to get out_name - out_name = arg_name[:-3] - assert out_name in outputs.keys() - num_outs = len(outputs[out_name]) - arg_to_append = num_outs - else: - arg_to_append = None + if type in final_state_name_mapping.keys(): + final_state_type = final_state_name_mapping[type][ + "final_op_name"] - if arg_to_append is None: - arg_list.append(arg_to_append) - elif arg_type == "tensor": - if isinstance(arg_to_append, list): - arg_list.append(arg_to_append[0]) - else: - arg_list.append(arg_to_append) - elif arg_type == "list": - assert isinstance(arg_to_append, list) - arg_list.append(arg_to_append) - else: - assert arg_type == "int" - assert isinstance(arg_to_append, int) - arg_list.append(arg_to_append) - - attrs_list = [] - for k, v in attrs.items(): - attrs_list.append(k) - attrs_list.append(v) - returns = function_ptr(*arg_list, *attrs_list) - - if isinstance(returns, tuple): - for i in range(len(op_returns)): - retname = op_returns[i] - if retname in outputs.keys(): - # Replaced outputs by function returns - if isinstance(returns[i], list): - for j in range(len(returns[i])): - outputs[retname][j].reconstruct_from_(returns[i] - [j]) - else: - outputs[retname][0].reconstruct_from_(returns[i]) - elif isinstance(returns, list): - assert len(outputs.keys()) == 1 - key = list(outputs.keys())[0] - for j in range(len(returns)): - outputs[key][j].reconstruct_from_(returns[j]) + assert final_state_type in _C_ops.__dict__ + self.eager_final_state_trace_op(type, inputs, outputs, attrs, + stop_gradient, inplace_map) else: - assert len(outputs.keys()) == 1 - key = list(outputs.keys())[0] - if isinstance(outputs[key], list): - outputs[key][0].reconstruct_from_(returns) - else: - outputs[key].reconstruct_from_(returns) + self.eager_trace_op(type, inputs, outputs, attrs, stop_gradient, + inplace_map) else: self.trace(type, inputs, outputs, attrs, framework._current_expected_place(), self._has_grad and diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/ps_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/ps_pass_test_base.py index 1848fa04b4f50fbe85c4d139eedc3f48fc48b529..63dd4b8e21e074654c528a6807e7c91e0b32141c 100755 --- a/python/paddle/fluid/tests/unittests/distributed_passes/ps_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/ps_pass_test_base.py @@ -22,6 +22,7 @@ import inspect import unittest import numpy as np from collections import OrderedDict +from paddle.distributed.ps.utils.public import logger from dist_pass_test_base import prepare_python_path_and_return_module, remove_path_if_exists import paddle.distributed.fleet as fleet @@ -37,7 +38,7 @@ class PsPassTestBase(unittest.TestCase): print('Ps tearDown...') def ps_launch(self, config, ps_mode="cpu-ps"): - if ps_mode == "cpu-ps": + if ps_mode == "cpu-ps" or ps_mode == 'heter-ps': os.environ['WITH_DISTRIBUTE'] = 'ON' cmd = [ @@ -45,7 +46,16 @@ class PsPassTestBase(unittest.TestCase): "-u", ] + [ "-m", "launch", "--log_dir", config['log_dir'], "--worker_num", - config['worker_num'], "--server_num", config['server_num'], + config['worker_num'], "--server_num", config['server_num'] + ] + if ps_mode == 'heter-ps': + os.environ['FLAGS_START_PORT'] = '12004' + cmd += [ + '--heter_worker_num', config['heter_worker_num'], + '--heter_devices', config['heter_devices'] + ] + + cmd += [ "../ps/ps_dnn_trainer.py", "-m", config['ps_mode_config'], "--run_minimize", config['run_minimize'], "--run_single_pass", config['run_single_pass'], "--debug_new_pass", diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_ps_trainer_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_ps_trainer_pass.py index ac6dd17359e32d28612ae892fc698f61b11fa12b..f28e99fc00d97ae13689be208bd3b10727f053ef 100755 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_ps_trainer_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_ps_trainer_pass.py @@ -63,6 +63,27 @@ class TestPsTrainerPass(PsPassTestBase): self.check() + # heter ps 三阶段待测 + def test_ps_optimizer_minimize_heter(self): + self.init() + self.config['worker_num'] = "2" + self.config['server_num'] = "2" + self.config['heter_worker_num'] = '2' + self.config['heter_devices'] = 'gpu' + + self.config['run_minimize'] = '1' + self.config['ps_mode_config'] = "../ps/heter_ps_config.yaml" + + self.config['debug_new_minimize'] = '0' + self.config['log_dir'] = "/heter_log_old_minimize" + remove_path_if_exists(self.config['log_dir']) + self.ps_launch(self.config, 'heter-ps') + + self.config['debug_new_minimize'] = '1' + self.config['log_dir'] = "/heter_log_new_minimize" + remove_path_if_exists(self.config['log_dir']) + self.ps_launch(self.config, 'heter-ps') + def test_ps_optimizer_minimize_gpu(self): self.init() self.config['run_minimize'] = '1' diff --git a/python/paddle/fluid/tests/unittests/ps/heter_ps_config.yaml b/python/paddle/fluid/tests/unittests/ps/heter_ps_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0c48e242d91d50cbb95ffc73e32528837030485 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ps/heter_ps_config.yaml @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +hyper_parameters: + optimizer: + class: Adam + learning_rate: 0.0001 + strategy: async # 有用 + sparse_inputs_slots: 27 + sparse_feature_number: 1024 + sparse_feature_dim: 11 + dense_input_dim: 13 + fc_sizes: [512, 256, 128, 32] + distributed_embedding: 0 + +runner: + sync_mode: "heter" + thread_num: 8 + micro_num: 8 # micro batch num for each thread + pipeline: True + + model_path: "../ps_dnn_model.py" + + diff --git a/python/paddle/fluid/tests/unittests/ps/ps_dnn_trainer.py b/python/paddle/fluid/tests/unittests/ps/ps_dnn_trainer.py index 2b6ce2e71130ae1dd3f4246505587ea9b66d3bd1..8f8ff65af544a1c4ddb4f1548603b418d3bf8bed 100755 --- a/python/paddle/fluid/tests/unittests/ps/ps_dnn_trainer.py +++ b/python/paddle/fluid/tests/unittests/ps/ps_dnn_trainer.py @@ -23,7 +23,6 @@ import yaml, six, copy import paddle import os import warnings -import logging import ast import numpy as np import struct @@ -176,6 +175,10 @@ def get_user_defined_strategy(config): strategy = paddle.distributed.fleet.DistributedStrategy() strategy.a_sync = True strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"} + strategy.pipeline = True + strategy.pipeline_configs = { + "accumulate_steps": config.get('runner.micro_num') + } elif sync_mode == "gpubox": print("sync_mode = {}".format(sync_mode)) strategy = paddle.distributed.fleet.DistributedStrategy() @@ -328,6 +331,7 @@ class DnnTrainer(object): if self.config['debug_new_minimize'] == 1: logger.info("entering run_minimize -- new") + self.role_maker._generate_role() # 必要 from paddle.distributed.fleet.meta_optimizers.ps_optimizer import ParameterServerOptimizer ps_optimizer = ParameterServerOptimizer(inner_optimizer) ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer, diff --git a/python/paddle/fluid/tests/unittests/ps_dnn_model.py b/python/paddle/fluid/tests/unittests/ps_dnn_model.py old mode 100644 new mode 100755 index 1a42df030b3bffe49151f9d2a1270347418e03b4..0a147334dab264eb8846bcf216bfc38da1d43b02 --- a/python/paddle/fluid/tests/unittests/ps_dnn_model.py +++ b/python/paddle/fluid/tests/unittests/ps_dnn_model.py @@ -17,6 +17,7 @@ import paddle.nn as nn import paddle.nn.functional as F import math import paddle.distributed.fleet as fleet +from paddle.distributed.ps.utils.public import logger class DNNLayer(nn.Layer): @@ -77,8 +78,13 @@ class DNNLayer(nn.Layer): y_dnn = paddle.concat(x=sparse_embs + [dense_inputs], axis=1) - for n_layer in self._mlp_layers: - y_dnn = n_layer(y_dnn) + if self.sync_mode == 'heter': + with paddle.fluid.device_guard('gpu'): + for n_layer in self._mlp_layers: + y_dnn = n_layer(y_dnn) + else: + for n_layer in self._mlp_layers: + y_dnn = n_layer(y_dnn) return y_dnn diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py index 5b92fcf52def0cd0788787606813d65afec3bb84..d601117b96f12d35756b521b85902bf91ef01bae 100644 --- a/python/paddle/fluid/tests/unittests/test_where_op.py +++ b/python/paddle/fluid/tests/unittests/test_where_op.py @@ -1,11 +1,11 @@ -#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import print_function - import unittest import numpy as np import paddle @@ -24,38 +23,39 @@ from op_test import OpTest from paddle.fluid import compiler, Program, program_guard from paddle.fluid.op import Operator from paddle.fluid.backward import append_backward +from paddle.fluid.framework import _test_eager_guard class TestWhereOp(OpTest): def setUp(self): - self.op_type = "where" + self.op_type = 'where' self.init_config() self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y} self.outputs = {'Out': np.where(self.cond, self.x, self.y)} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X', 'Y'], 'Out') + self.check_grad(['X', 'Y'], 'Out', check_eager=True) def init_config(self): - self.x = np.random.uniform(-3, 5, (100)).astype("float64") - self.y = np.random.uniform(-3, 5, (100)).astype("float64") - self.cond = np.zeros((100)).astype("bool") + self.x = np.random.uniform((-3), 5, 100).astype('float64') + self.y = np.random.uniform((-3), 5, 100).astype('float64') + self.cond = np.zeros(100).astype('bool') class TestWhereOp2(TestWhereOp): def init_config(self): - self.x = np.random.uniform(-5, 5, (60, 2)).astype("float64") - self.y = np.random.uniform(-5, 5, (60, 2)).astype("float64") - self.cond = np.ones((60, 2)).astype("bool") + self.x = np.random.uniform((-5), 5, (60, 2)).astype('float64') + self.y = np.random.uniform((-5), 5, (60, 2)).astype('float64') + self.cond = np.ones((60, 2)).astype('bool') class TestWhereOp3(TestWhereOp): def init_config(self): - self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64") - self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64") + self.x = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64') + self.y = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64') self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool) @@ -66,15 +66,15 @@ class TestWhereAPI(unittest.TestCase): def init_data(self): self.shape = [10, 15] self.cond = np.array(np.random.randint(2, size=self.shape), dtype=bool) - self.x = np.random.uniform(-2, 3, self.shape).astype(np.float32) - self.y = np.random.uniform(-2, 3, self.shape).astype(np.float32) + self.x = np.random.uniform((-2), 3, self.shape).astype(np.float32) + self.y = np.random.uniform((-2), 3, self.shape).astype(np.float32) self.out = np.where(self.cond, self.x, self.y) def ref_x_backward(self, dout): - return np.where(self.cond == True, dout, 0) + return np.where((self.cond == True), dout, 0) def ref_y_backward(self, dout): - return np.where(self.cond == False, dout, 0) + return np.where((self.cond == False), dout, 0) def test_api(self, use_cuda=False): for x_stop_gradient in [False, True]: @@ -90,17 +90,17 @@ class TestWhereAPI(unittest.TestCase): y.stop_gradient = y_stop_gradient result = paddle.where(cond, x, y) append_backward(layers.mean(result)) - for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): + if (use_cuda and + (not fluid.core.is_compiled_with_cuda())): break - place = fluid.CUDAPlace( - 0) if use_cuda else fluid.CPUPlace() + place = (fluid.CUDAPlace(0) + if use_cuda else fluid.CPUPlace()) exe = fluid.Executor(place) fetch_list = [result, result.grad_name] - if x_stop_gradient is False: + if (x_stop_gradient is False): fetch_list.append(x.grad_name) - if y_stop_gradient is False: + if (y_stop_gradient is False): fetch_list.append(y.grad_name) out = exe.run( fluid.default_main_program(), @@ -109,13 +109,13 @@ class TestWhereAPI(unittest.TestCase): 'y': self.y}, fetch_list=fetch_list) assert np.array_equal(out[0], self.out) - if x_stop_gradient is False: + if (x_stop_gradient is False): assert np.array_equal(out[2], self.ref_x_backward(out[1])) - if y.stop_gradient is False: + if (y.stop_gradient is False): assert np.array_equal( out[3], self.ref_y_backward(out[1])) - elif y.stop_gradient is False: + elif (y.stop_gradient is False): assert np.array_equal(out[2], self.ref_y_backward(out[1])) @@ -124,44 +124,38 @@ class TestWhereAPI(unittest.TestCase): with fluid.program_guard(main_program): x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32') y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32') - x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32") - y_i = np.array([[1.0, 1.0, 1.0, 1.0], - [1.0, 1.0, 1.0, 1.0]]).astype("float32") - result = paddle.where(x > 1, x=x, y=y) - + x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype('float32') + y_i = np.array( + [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype('float32') + result = paddle.where((x > 1), x=x, y=y) for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): + if (use_cuda and (not fluid.core.is_compiled_with_cuda())): return - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()) exe = fluid.Executor(place) out = exe.run(fluid.default_main_program(), feed={'x': x_i, 'y': y_i}, fetch_list=[result]) - assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i)) + assert np.array_equal(out[0], np.where((x_i > 1), x_i, y_i)) def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): paddle.enable_static() - main_program = Program() with fluid.program_guard(main_program): cond = fluid.layers.data( name='cond', shape=cond_shape, dtype='bool') x = fluid.layers.data(name='x', shape=x_shape, dtype='float32') y = fluid.layers.data(name='y', shape=y_shape, dtype='float32') - - cond_data_tmp = np.random.random(size=cond_shape).astype("float32") - cond_data = cond_data_tmp < 0.3 - x_data = np.random.random(size=x_shape).astype("float32") - y_data = np.random.random(size=y_shape).astype("float32") - + cond_data_tmp = np.random.random(size=cond_shape).astype('float32') + cond_data = (cond_data_tmp < 0.3) + x_data = np.random.random(size=x_shape).astype('float32') + y_data = np.random.random(size=y_shape).astype('float32') result = paddle.where(condition=cond, x=x, y=y) - for use_cuda in [False, True]: - if use_cuda and not fluid.core.is_compiled_with_cuda(): + if (use_cuda and (not fluid.core.is_compiled_with_cuda())): return - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - + place = (fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()) exe = fluid.Executor(place) out = exe.run( fluid.default_main_program(), @@ -169,9 +163,7 @@ class TestWhereAPI(unittest.TestCase): 'x': x_data, 'y': y_data}, fetch_list=[result]) - expect = np.where(cond_data, x_data, y_data) - assert np.array_equal(out[0], expect) def test_static_api_broadcast_1(self): @@ -198,28 +190,24 @@ class TestWhereAPI(unittest.TestCase): b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_static_api_broadcast_5(self): cond_shape = [3, 2, 2, 4] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_static_api_broadcast_6(self): cond_shape = [2, 2, 4] a_shape = [2, 2, 1] b_shape = [2, 2, 1] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_static_api_broadcast_7(self): cond_shape = [2, 2, 4] a_shape = [2, 1, 4] b_shape = [2, 1, 4] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_static_api_broadcast_8(self): cond_shape = [3, 2, 2, 4] a_shape = [2, 2, 1] @@ -230,9 +218,9 @@ class TestWhereAPI(unittest.TestCase): class TestWhereDygraphAPI(unittest.TestCase): def test_api(self): with fluid.dygraph.guard(): - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64") - cond_i = np.array([False, False, True, True]).astype("bool") + x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64') + y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64') + cond_i = np.array([False, False, True, True]).astype('bool') x = fluid.dygraph.to_variable(x_i) y = fluid.dygraph.to_variable(y_i) cond = fluid.dygraph.to_variable(cond_i) @@ -242,15 +230,12 @@ class TestWhereDygraphAPI(unittest.TestCase): def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape): with fluid.dygraph.guard(): cond_tmp = paddle.rand(cond_shape) - cond = cond_tmp < 0.3 + cond = (cond_tmp < 0.3) a = paddle.rand(a_shape) b = paddle.rand(b_shape) - result = paddle.where(cond, a, b) result = result.numpy() - expect = np.where(cond, a, b) - self.assertTrue(np.array_equal(expect, result)) def test_dygraph_api_broadcast_1(self): @@ -277,28 +262,24 @@ class TestWhereDygraphAPI(unittest.TestCase): b_shape = [2, 2, 4] self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_dygraph_api_broadcast_5(self): cond_shape = [3, 2, 2, 4] a_shape = [2, 2, 4] b_shape = [2, 2, 4] self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_dygraph_api_broadcast_6(self): cond_shape = [2, 2, 4] a_shape = [2, 2, 1] b_shape = [2, 2, 1] self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_dygraph_api_broadcast_7(self): cond_shape = [2, 2, 4] a_shape = [2, 1, 4] b_shape = [2, 1, 4] self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape) - # @Note Now, maybe not compatibility with old version def test_dygraph_api_broadcast_8(self): cond_shape = [3, 2, 2, 4] a_shape = [2, 2, 1] @@ -308,40 +289,50 @@ class TestWhereDygraphAPI(unittest.TestCase): def test_where_condition(self): data = np.array([[True, False], [False, True]]) with program_guard(Program(), Program()): - x = fluid.layers.data(name='x', shape=[-1, 2]) + x = fluid.layers.data(name='x', shape=[(-1), 2]) y = paddle.where(x) self.assertEqual(type(y), tuple) self.assertEqual(len(y), 2) z = fluid.layers.concat(list(y), axis=1) exe = fluid.Executor(fluid.CPUPlace()) - - res, = exe.run(feed={'x': data}, - fetch_list=[z.name], - return_numpy=False) + (res, ) = exe.run(feed={'x': data}, + fetch_list=[z.name], + return_numpy=False) expect_out = np.array([[0, 0], [1, 1]]) self.assertTrue(np.allclose(expect_out, np.array(res))) - data = np.array([True, True, False]) with program_guard(Program(), Program()): - x = fluid.layers.data(name='x', shape=[-1]) + x = fluid.layers.data(name='x', shape=[(-1)]) y = paddle.where(x) self.assertEqual(type(y), tuple) self.assertEqual(len(y), 1) z = fluid.layers.concat(list(y), axis=1) exe = fluid.Executor(fluid.CPUPlace()) - res, = exe.run(feed={'x': data}, - fetch_list=[z.name], - return_numpy=False) + (res, ) = exe.run(feed={'x': data}, + fetch_list=[z.name], + return_numpy=False) expect_out = np.array([[0], [1]]) self.assertTrue(np.allclose(expect_out, np.array(res))) + def test_eager(self): + with _test_eager_guard(): + self.test_api() + self.test_dygraph_api_broadcast_1() + self.test_dygraph_api_broadcast_2() + self.test_dygraph_api_broadcast_3() + self.test_dygraph_api_broadcast_4() + self.test_dygraph_api_broadcast_5() + self.test_dygraph_api_broadcast_6() + self.test_dygraph_api_broadcast_7() + self.test_dygraph_api_broadcast_8() + class TestWhereOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): - x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64") - y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64") - cond_i = np.array([False, False, True, True]).astype("bool") + x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype('float64') + y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype('float64') + cond_i = np.array([False, False, True, True]).astype('bool') def test_Variable(): paddle.where(cond_i, x_i, y_i) @@ -360,10 +351,14 @@ class TestWhereOpError(unittest.TestCase): with fluid.dygraph.guard(): cond_shape = [2, 2, 4] cond_tmp = paddle.rand(cond_shape) - cond = cond_tmp < 0.3 + cond = (cond_tmp < 0.3) a = paddle.rand(cond_shape) self.assertRaises(ValueError, paddle.where, cond, a) + def test_eager(self): + with _test_eager_guard(): + self.test_value_error() + -if __name__ == '__main__': +if (__name__ == '__main__'): unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index 5793f0148fc5475a89c3b53831bc2019af542b61..043c5c1651a09ac022d8a694b2e916b613c77f6b 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -1,11 +1,11 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,23 +13,22 @@ # limitations under the License. from __future__ import division - import unittest import numpy as np from op_test import OpTest - import paddle from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard def sigmoid(x): - return 1.0 / (1.0 + np.exp(-1.0 * x)) + return (1.0 / (1.0 + np.exp(((-1.0) * x)))) def YoloBox(x, img_size, attrs): - n, c, h, w = x.shape + (n, c, h, w) = x.shape anchors = attrs['anchors'] - an_num = int(len(anchors) // 2) + an_num = int((len(anchors) // 2)) class_num = attrs['class_num'] conf_thresh = attrs['conf_thresh'] downsample = attrs['downsample'] @@ -37,60 +36,56 @@ def YoloBox(x, img_size, attrs): scale_x_y = attrs['scale_x_y'] iou_aware = attrs['iou_aware'] iou_aware_factor = attrs['iou_aware_factor'] - bias_x_y = -0.5 * (scale_x_y - 1.) - input_h = downsample * h - input_w = downsample * w - + bias_x_y = ((-0.5) * (scale_x_y - 1.0)) + input_h = (downsample * h) + input_w = (downsample * w) if iou_aware: ioup = x[:, :an_num, :, :] - ioup = np.expand_dims(ioup, axis=-1) + ioup = np.expand_dims(ioup, axis=(-1)) x = x[:, an_num:, :, :] - x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) - + x = x.reshape((n, an_num, (5 + class_num), h, w)).transpose((0, 1, 3, 4, 2)) pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) - pred_box[:, :, :, :, 0] = ( - grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w - pred_box[:, :, :, :, 1] = ( - grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h - - anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] + pred_box[:, :, :, :, 0] = (( + (grid_x + (sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y)) + bias_x_y) / + w) + pred_box[:, :, :, :, 1] = (( + (grid_y + (sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y)) + bias_x_y) / + h) + anchors = [(anchors[i], anchors[(i + 1)]) + for i in range(0, len(anchors), 2)] anchors_s = np.array( - [(an_w / input_w, an_h / input_h) for an_w, an_h in anchors]) + [((an_w / input_w), (an_h / input_h)) for (an_w, an_h) in anchors]) anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) - pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w - pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h - + pred_box[:, :, :, :, 2] = (np.exp(pred_box[:, :, :, :, 2]) * anchor_w) + pred_box[:, :, :, :, 3] = (np.exp(pred_box[:, :, :, :, 3]) * anchor_h) if iou_aware: - pred_conf = sigmoid(x[:, :, :, :, 4:5])**( - 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor + pred_conf = ((sigmoid(x[:, :, :, :, 4:5])**(1 - iou_aware_factor)) * + (sigmoid(ioup)**iou_aware_factor)) else: pred_conf = sigmoid(x[:, :, :, :, 4:5]) - pred_conf[pred_conf < conf_thresh] = 0. - pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf - pred_box = pred_box * (pred_conf > 0.).astype('float32') - - pred_box = pred_box.reshape((n, -1, 4)) - pred_box[:, :, :2], pred_box[:, :, 2:4] = \ - pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \ - pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0 - pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis] - pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis] - pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis] - pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis] - + pred_conf[(pred_conf < conf_thresh)] = 0.0 + pred_score = (sigmoid(x[:, :, :, :, 5:]) * pred_conf) + pred_box = (pred_box * (pred_conf > 0.0).astype('float32')) + pred_box = pred_box.reshape((n, (-1), 4)) + (pred_box[:, :, :2], pred_box[:, :, 2:4]) = ( + (pred_box[:, :, :2] - (pred_box[:, :, 2:4] / 2.0)), + (pred_box[:, :, :2] + (pred_box[:, :, 2:4] / 2.0))) + pred_box[:, :, 0] = (pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]) + pred_box[:, :, 1] = (pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]) + pred_box[:, :, 2] = (pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]) + pred_box[:, :, 3] = (pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]) if clip_bbox: for i in range(len(pred_box)): pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf) pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf) - pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf, - img_size[i, 1] - 1) - pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf, - img_size[i, 0] - 1) - - return pred_box, pred_score.reshape((n, -1, class_num)) + pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], (-np.inf), + (img_size[(i, 1)] - 1)) + pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], (-np.inf), + (img_size[(i, 0)] - 1)) + return (pred_box, pred_score.reshape((n, (-1), class_num))) class TestYoloBoxOp(OpTest): @@ -99,42 +94,35 @@ class TestYoloBoxOp(OpTest): self.op_type = 'yolo_box' x = np.random.random(self.x_shape).astype('float32') img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32') - self.attrs = { - "anchors": self.anchors, - "class_num": self.class_num, - "conf_thresh": self.conf_thresh, - "downsample": self.downsample, - "clip_bbox": self.clip_bbox, - "scale_x_y": self.scale_x_y, - "iou_aware": self.iou_aware, - "iou_aware_factor": self.iou_aware_factor - } - - self.inputs = { - 'X': x, - 'ImgSize': img_size, - } - boxes, scores = YoloBox(x, img_size, self.attrs) - self.outputs = { - "Boxes": boxes, - "Scores": scores, + 'anchors': self.anchors, + 'class_num': self.class_num, + 'conf_thresh': self.conf_thresh, + 'downsample': self.downsample, + 'clip_bbox': self.clip_bbox, + 'scale_x_y': self.scale_x_y, + 'iou_aware': self.iou_aware, + 'iou_aware_factor': self.iou_aware_factor } + self.inputs = {'X': x, 'ImgSize': img_size} + (boxes, scores) = YoloBox(x, img_size, self.attrs) + self.outputs = {'Boxes': boxes, 'Scores': scores} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] - an_num = int(len(self.anchors) // 2) + an_num = int((len(self.anchors) // 2)) self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 self.clip_bbox = True - self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) + self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, + 13) self.imgsize_shape = (self.batch_size, 2) - self.scale_x_y = 1. + self.scale_x_y = 1.0 self.iou_aware = False self.iou_aware_factor = 0.5 @@ -142,15 +130,16 @@ class TestYoloBoxOp(OpTest): class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] - an_num = int(len(self.anchors) // 2) + an_num = int((len(self.anchors) // 2)) self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 self.clip_bbox = False - self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) + self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, + 13) self.imgsize_shape = (self.batch_size, 2) - self.scale_x_y = 1. + self.scale_x_y = 1.0 self.iou_aware = False self.iou_aware_factor = 0.5 @@ -158,13 +147,14 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): class TestYoloBoxOpScaleXY(TestYoloBoxOp): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] - an_num = int(len(self.anchors) // 2) + an_num = int((len(self.anchors) // 2)) self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 self.clip_bbox = True - self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) + self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, + 13) self.imgsize_shape = (self.batch_size, 2) self.scale_x_y = 1.2 self.iou_aware = False @@ -174,15 +164,16 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp): class TestYoloBoxOpIoUAware(TestYoloBoxOp): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] - an_num = int(len(self.anchors) // 2) + an_num = int((len(self.anchors) // 2)) self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 self.clip_bbox = True - self.x_shape = (self.batch_size, an_num * (6 + self.class_num), 13, 13) + self.x_shape = (self.batch_size, (an_num * (6 + self.class_num)), 13, + 13) self.imgsize_shape = (self.batch_size, 2) - self.scale_x_y = 1. + self.scale_x_y = 1.0 self.iou_aware = True self.iou_aware_factor = 0.5 @@ -192,10 +183,9 @@ class TestYoloBoxDygraph(unittest.TestCase): paddle.disable_static() img_size = np.ones((2, 2)).astype('int32') img_size = paddle.to_tensor(img_size) - x1 = np.random.random([2, 14, 8, 8]).astype('float32') x1 = paddle.to_tensor(x1) - boxes, scores = paddle.vision.ops.yolo_box( + (boxes, scores) = paddle.vision.ops.yolo_box( x1, img_size=img_size, anchors=[10, 13, 16, 30], @@ -203,12 +193,11 @@ class TestYoloBoxDygraph(unittest.TestCase): conf_thresh=0.01, downsample_ratio=8, clip_bbox=True, - scale_x_y=1.) - assert boxes is not None and scores is not None - + scale_x_y=1.0) + assert ((boxes is not None) and (scores is not None)) x2 = np.random.random([2, 16, 8, 8]).astype('float32') x2 = paddle.to_tensor(x2) - boxes, scores = paddle.vision.ops.yolo_box( + (boxes, scores) = paddle.vision.ops.yolo_box( x2, img_size=img_size, anchors=[10, 13, 16, 30], @@ -216,18 +205,21 @@ class TestYoloBoxDygraph(unittest.TestCase): conf_thresh=0.01, downsample_ratio=8, clip_bbox=True, - scale_x_y=1., + scale_x_y=1.0, iou_aware=True, iou_aware_factor=0.5) paddle.enable_static() + def test_eager(self): + with _test_eager_guard(): + self.test_dygraph() + class TestYoloBoxStatic(unittest.TestCase): def test_static(self): x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32') img_size = paddle.static.data('img_size', [2, 2], 'int32') - - boxes, scores = paddle.vision.ops.yolo_box( + (boxes, scores) = paddle.vision.ops.yolo_box( x1, img_size=img_size, anchors=[10, 13, 16, 30], @@ -235,11 +227,10 @@ class TestYoloBoxStatic(unittest.TestCase): conf_thresh=0.01, downsample_ratio=8, clip_bbox=True, - scale_x_y=1.) - assert boxes is not None and scores is not None - + scale_x_y=1.0) + assert ((boxes is not None) and (scores is not None)) x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32') - boxes, scores = paddle.vision.ops.yolo_box( + (boxes, scores) = paddle.vision.ops.yolo_box( x2, img_size=img_size, anchors=[10, 13, 16, 30], @@ -247,27 +238,27 @@ class TestYoloBoxStatic(unittest.TestCase): conf_thresh=0.01, downsample_ratio=8, clip_bbox=True, - scale_x_y=1., + scale_x_y=1.0, iou_aware=True, iou_aware_factor=0.5) - assert boxes is not None and scores is not None + assert ((boxes is not None) and (scores is not None)) class TestYoloBoxOpHW(TestYoloBoxOp): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] - an_num = int(len(self.anchors) // 2) + an_num = int((len(self.anchors) // 2)) self.batch_size = 32 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 self.clip_bbox = False - self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9) + self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, 9) self.imgsize_shape = (self.batch_size, 2) - self.scale_x_y = 1. + self.scale_x_y = 1.0 self.iou_aware = False self.iou_aware_factor = 0.5 -if __name__ == "__main__": +if (__name__ == '__main__'): unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_zeros_like_op.py b/python/paddle/fluid/tests/unittests/test_zeros_like_op.py index 6546d7b99f44163e923fe9ed4d84acbe962a7995..80b4db793ff439d6858c6e74db869ac75bd5f23c 100644 --- a/python/paddle/fluid/tests/unittests/test_zeros_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_zeros_like_op.py @@ -1,11 +1,11 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,13 +13,13 @@ # limitations under the License. from __future__ import print_function - import unittest import numpy as np import paddle import paddle.fluid as fluid from paddle import zeros_like from paddle.fluid import core, Program, program_guard +from paddle.fluid.framework import _test_eager_guard class TestZerosLikeAPIError(unittest.TestCase): @@ -28,6 +28,10 @@ class TestZerosLikeAPIError(unittest.TestCase): x = paddle.fluid.data('x', [3, 4]) self.assertRaises(TypeError, zeros_like, x, 'int8') + def test_eager(self): + with _test_eager_guard(): + self.test_errors() + class TestZerosLikeAPI(unittest.TestCase): def test_api(self): @@ -36,46 +40,48 @@ class TestZerosLikeAPI(unittest.TestCase): train_program = Program() with program_guard(train_program, startup_program): x = paddle.fluid.data('X', shape) - - # 'bool', 'float32', 'float64', 'int32', 'int64' out1 = zeros_like(x) out2 = zeros_like(x, np.bool) out3 = zeros_like(x, 'float64') out4 = zeros_like(x, 'int32') out5 = zeros_like(x, 'int64') - - place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( - ) else fluid.CPUPlace() + place = (fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() else fluid.CPUPlace()) exe = fluid.Executor(place) outs = exe.run(train_program, feed={'X': np.ones(shape).astype('float32')}, fetch_list=[out1, out2, out3, out4, out5]) - - for i, dtype in enumerate( + for (i, dtype) in enumerate( [np.float32, np.bool, np.float64, np.int32, np.int64]): self.assertEqual(outs[i].dtype, dtype) self.assertEqual((outs[i] == np.zeros(shape, dtype)).all(), True) + def test_eager(self): + with _test_eager_guard(): + self.test_api() + class TestZerosLikeImpeartive(unittest.TestCase): def test_out(self): shape = [3, 4] - place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( - ) else fluid.CPUPlace() + place = (fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() else fluid.CPUPlace()) paddle.disable_static(place) x = paddle.to_tensor(np.ones(shape)) for dtype in [np.bool, np.float32, np.float64, np.int32, np.int64]: out = zeros_like(x, dtype) self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True) - out = paddle.tensor.zeros_like(x) self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True) - out = paddle.tensor.creation.zeros_like(x) self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True) paddle.enable_static() + def test_eager(self): + with _test_eager_guard(): + self.test_out() + -if __name__ == "__main__": +if (__name__ == '__main__'): unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_zeros_op.py b/python/paddle/fluid/tests/unittests/test_zeros_op.py index 23dec935507fd977f884e952451b5ea98c935893..449f95aac297ae9f0210187f3a02a30561bb861a 100644 --- a/python/paddle/fluid/tests/unittests/test_zeros_op.py +++ b/python/paddle/fluid/tests/unittests/test_zeros_op.py @@ -1,11 +1,11 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,56 +13,55 @@ # limitations under the License. from __future__ import print_function - import unittest import numpy as np from op_test import OpTest - import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import _test_eager_guard class TestZerosOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): - # The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64. shape = [4] - dtype = "int8" + dtype = 'int8' self.assertRaises(TypeError, fluid.layers.zeros, shape, dtype) + def test_eager(self): + with _test_eager_guard(): + self.test_errors() + class ApiZerosTest(unittest.TestCase): def test_out(self): with program_guard(Program()): - zeros = paddle.zeros(shape=[10], dtype="float64") + zeros = paddle.zeros(shape=[10], dtype='float64') place = paddle.CPUPlace() exe = paddle.static.Executor(place) - result, = exe.run(fetch_list=[zeros]) - expected_result = np.zeros(10, dtype="float64") + (result, ) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros(10, dtype='float64') self.assertEqual((result == expected_result).all(), True) - with paddle.static.program_guard(Program()): - zeros = paddle.zeros(shape=[10], dtype="int64") + zeros = paddle.zeros(shape=[10], dtype='int64') place = paddle.CPUPlace() exe = paddle.static.Executor(place) - result, = exe.run(fetch_list=[zeros]) - expected_result = np.zeros(10, dtype="int64") + (result, ) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros(10, dtype='int64') self.assertEqual((result == expected_result).all(), True) - with program_guard(Program()): - zeros = paddle.zeros(shape=[10], dtype="int64") + zeros = paddle.zeros(shape=[10], dtype='int64') place = paddle.CPUPlace() exe = paddle.static.Executor(place) - result, = exe.run(fetch_list=[zeros]) - expected_result = np.zeros(10, dtype="int64") + (result, ) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros(10, dtype='int64') self.assertEqual((result == expected_result).all(), True) - with program_guard(Program()): - out_np = np.zeros(shape=(1), dtype='float32') - out = paddle.zeros(shape=[1], dtype="float32") + out_np = np.zeros(shape=1, dtype='float32') + out = paddle.zeros(shape=[1], dtype='float32') place = paddle.CPUPlace() exe = paddle.static.Executor(place) result = exe.run(fetch_list=[out]) @@ -70,28 +69,37 @@ class ApiZerosTest(unittest.TestCase): def test_fluid_out(self): with program_guard(Program()): - zeros = fluid.layers.zeros(shape=[10], dtype="int64") + zeros = fluid.layers.zeros(shape=[10], dtype='int64') place = paddle.CPUPlace() exe = paddle.static.Executor(place) - result, = exe.run(fetch_list=[zeros]) - expected_result = np.zeros(10, dtype="int64") + (result, ) = exe.run(fetch_list=[zeros]) + expected_result = np.zeros(10, dtype='int64') self.assertEqual((result == expected_result).all(), True) + def test_eager(self): + with _test_eager_guard(): + self.test_out() + self.test_fluid_out() + class ApiZerosError(unittest.TestCase): def test_errors(self): def test_error1(): with paddle.static.program_guard(fluid.Program()): - ones = fluid.layers.zeros(shape=10, dtype="int64") + ones = fluid.layers.zeros(shape=10, dtype='int64') self.assertRaises(TypeError, test_error1) def test_error2(): with paddle.static.program_guard(fluid.Program()): - ones = fluid.layers.zeros(shape=[10], dtype="int8") + ones = fluid.layers.zeros(shape=[10], dtype='int8') self.assertRaises(TypeError, test_error2) + def test_eager(self): + with _test_eager_guard(): + self.test_errors() + -if __name__ == "__main__": +if (__name__ == '__main__'): unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py index 31246436efae23f25eed998d673b71fd6e2c0377..4eae44846efc701d90a5a4ad03c6e0e29dad77c7 100644 --- a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py +++ b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py @@ -17,6 +17,7 @@ from __future__ import print_function import inspect import os import fcntl +import numpy as np import paddle import paddle.fluid.core as core @@ -29,28 +30,61 @@ type_dict_paddle_to_str = { paddle.int32: 'int32', paddle.int64: 'int64', paddle.float16: 'float16', + paddle.bfloat16: 'bfloat16', paddle.float32: 'float32', paddle.float64: 'float64', paddle.complex128: 'complex128', paddle.complex64: 'complex64', } +type_dict_paddle_to_numpy = { + paddle.bool: np.bool_, + paddle.uint8: np.uint8, + paddle.int8: np.int8, + paddle.int16: np.int16, + paddle.int32: np.int32, + paddle.int64: np.int64, + paddle.bfloat16: np.uint16, + paddle.float16: np.float16, + paddle.float32: np.float32, + paddle.float64: np.float64, + paddle.complex128: np.complex128, + paddle.complex64: np.complex64, +} + type_dict_str_to_paddle = { + 'uint8': paddle.uint8, + 'int8': paddle.int8, + 'int16': paddle.int16, 'int32': paddle.int32, 'int64': paddle.int64, - 'float32': paddle.float32, + 'bfloat16': paddle.bfloat16, 'float16': paddle.float16, + 'float32': paddle.float32, + 'float64': paddle.float64, 'bool': paddle.bool, - 'uint8': paddle.uint8, - 'int8': paddle.int8, - 'complex128': paddle.complex128, 'complex64': paddle.complex64, - 'int16': paddle.int16, + 'complex128': paddle.complex128, +} + +type_dict_str_to_numpy = { + 'uint8': np.uint8, + 'int8': np.int8, + 'int16': np.int16, + 'int32': np.int32, + 'int64': np.int64, + 'bfloat16': np.uint16, + 'float16': np.float16, + 'float32': np.float32, + 'float64': np.float64, + 'bool': np.bool_, + 'complex64': np.complex64, + 'complex128': np.complex128, } xpu_test_op_white_list = [] xpu_test_type_white_list = [] -xpu_test_op_type_white_list = [] +xpu_test_op_type_white_list = ['float64'] xpu_test_device_op_white_list = [] xpu_test_device_op_type_white_list = [] @@ -122,6 +156,8 @@ def make_xpu_op_list(xpu_version): if op_name in op_white_list or device_op_name in device_op_white_list: continue for op_type in type_list: + if op_type == paddle.bfloat16: + op_type = paddle.bfloat16 if op_type in type_white_list or op_type not in type_dict_paddle_to_str.keys( ): continue @@ -143,10 +179,17 @@ def get_xpu_op_support_types(op_name, dev_id=0): xpu_version = core.get_xpu_device_version(dev_id) support_type_list = core.get_xpu_device_op_support_types(op_name, xpu_version) - support_type_str_list = [ - type_dict_paddle_to_str[x] for x in support_type_list + support_type_str_list = [] + for stype in support_type_list: + if stype == paddle.bfloat16: + support_type_str_list.append(type_dict_paddle_to_str[ + paddle.bfloat16]) + else: + support_type_str_list.append(type_dict_paddle_to_str[stype]) + type_white_list = get_op_type_white_list() + return [ + stype for stype in support_type_str_list if stype not in type_white_list ] - return support_type_str_list def record_op_test(op_name, test_type): @@ -196,8 +239,9 @@ def create_test_class(func_globals, continue class_obj = test_class[1] cls_name = "{0}_{1}".format(test_class[0], str(test_type)) - func_globals[cls_name] = type(cls_name, (class_obj, ), - {'in_type': test_type}) + func_globals[cls_name] = type( + cls_name, (class_obj, ), + {'in_type': type_dict_str_to_numpy[test_type]}) if hasattr(test_class_obj, 'use_dynamic_create_class' ) and test_class_obj.use_dynamic_create_class: @@ -205,7 +249,7 @@ def create_test_class(func_globals, for dy_class in dynamic_classes: cls_name = "{0}_{1}".format(dy_class[0], str(test_type)) attr_dict = dy_class[1] - attr_dict['in_type'] = test_type + attr_dict['in_type'] = type_dict_str_to_numpy[test_type] func_globals[cls_name] = type(cls_name, (base_class, ), attr_dict) record_op_test(op_name, test_type) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py index 6c58c7ccf2cc01b1ccd2b2828566e6c4fb67de8b..7f8f5d6bc747b31473827f590f773dfeb7a95a7e 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py @@ -24,92 +24,103 @@ from op_test_xpu import OpTest, XPUOpTest import paddle from paddle.fluid import Program, program_guard - -class TestClipOp(XPUOpTest): - def set_xpu(self): - self.__class__.use_xpu = True - self.place = paddle.XPUPlace(0) - - def setUp(self): - self.set_xpu() - self.max_relative_error = 0.006 - - self.inputs = {} - self.initTestCase() - - self.op_type = "clip" - self.attrs = {} - self.attrs['min'] = self.min - self.attrs['max'] = self.max - if 'Min' in self.inputs: - min_v = self.inputs['Min'] - else: - min_v = self.attrs['min'] - - if 'Max' in self.inputs: - max_v = self.inputs['Max'] - else: - max_v = self.attrs['max'] - - input = np.random.random(self.shape).astype("float32") - input[np.abs(input - min_v) < self.max_relative_error] = 0.5 - input[np.abs(input - max_v) < self.max_relative_error] = 0.5 - self.inputs['X'] = input - self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)} - - def test_check_output(self): - paddle.enable_static() - self.check_output_with_place(self.place) - paddle.disable_static() - - def test_check_grad_normal(self): - paddle.enable_static() - self.check_grad_with_place(self.place, ['X'], 'Out') - paddle.disable_static() - - def initTestCase(self): - self.shape = (4, 10, 10) - self.max = 0.8 - self.min = 0.3 - self.inputs['Max'] = np.array([0.8]).astype('float32') - self.inputs['Min'] = np.array([0.1]).astype('float32') - - -class TestCase1(TestClipOp): - def initTestCase(self): - self.shape = (8, 16, 8) - self.max = 0.7 - self.min = 0.0 - - -class TestCase2(TestClipOp): - def initTestCase(self): - self.shape = (8, 16) - self.max = 1.0 - self.min = 0.0 - - -class TestCase3(TestClipOp): - def initTestCase(self): - self.shape = (4, 8, 16) - self.max = 0.7 - self.min = 0.2 - - -class TestCase4(TestClipOp): - def initTestCase(self): - self.shape = (4, 8, 8) - self.max = 0.7 - self.min = 0.2 - self.inputs['Max'] = np.array([0.8]).astype('float32') - self.inputs['Min'] = np.array([0.3]).astype('float32') - - -class TestCase5(TestClipOp): - def initTestCase(self): - self.shape = (4, 8, 16) - self.max = 0.5 - self.min = 0.5 +import op_test +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + + +class XPUTestClipOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'clip' + self.use_dynamic_create_class = False + + class TestClipOp(XPUOpTest): + def setUp(self): + self.init_dtype() + self.set_xpu() + self.op_type = "clip" + self.place = paddle.XPUPlace(0) + self.inputs = {} + self.init_data() + self.set_attrs() + self.set_inputs() + self.outputs = { + 'Out': np.clip(self.inputs['X'], self.min_v, self.max_v) + } + + def set_xpu(self): + self.__class__.use_xpu = True + self.__class__.no_need_check_grad = True + self.__class__.op_type = self.dtype + + def init_data(self): + self.shape = (4, 10, 10) + self.max = 0.8 + self.min = 0.3 + + def set_inputs(self): + if 'Min' in self.inputs: + min_v = self.inputs['Min'] + else: + min_v = self.attrs['min'] + + if 'Max' in self.inputs: + max_v = self.inputs['Max'] + else: + max_v = self.attrs['max'] + + self.min_v = min_v + self.max_v = max_v + self.max_relative_error = 0.006 + input = np.random.random(self.shape).astype("float32") + input[np.abs(input - min_v) < self.max_relative_error] = 0.5 + input[np.abs(input - max_v) < self.max_relative_error] = 0.5 + self.inputs['X'] = input + + def set_attrs(self): + self.attrs = {} + self.attrs['min'] = self.min + self.attrs['max'] = self.max + + def init_dtype(self): + self.dtype = self.in_type + + def test_check_output(self): + paddle.enable_static() + self.check_output_with_place(self.place) + paddle.disable_static() + + class TestClipOp1(TestClipOp): + def init_data(self): + self.shape = (8, 16, 8) + self.max = 0.7 + self.min = 0.0 + + class TestClipOp2(TestClipOp): + def init_data(self): + self.shape = (8, 16) + self.max = 1.0 + self.min = 0.0 + + class TestClipOp3(TestClipOp): + def init_data(self): + self.shape = (4, 8, 16) + self.max = 0.7 + self.min = 0.2 + + class TestClipOp4(TestClipOp): + def init_data(self): + self.shape = (4, 8, 8) + self.max = 0.7 + self.min = 0.2 + self.inputs['Max'] = np.array([0.8]).astype('float32') + self.inputs['Min'] = np.array([0.3]).astype('float32') + + class TestClipOp5(TestClipOp): + def init_data(self): + self.shape = (4, 8, 16) + self.max = 0.5 + self.min = 0.5 class TestClipOpError(unittest.TestCase): @@ -212,5 +223,9 @@ class TestInplaceClipAPI(TestClipAPI): return x.clip_(min, max) +support_types = get_xpu_op_support_types('clip') +for stype in support_types: + create_test_class(globals(), XPUTestClipOp, stype) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_gather_nd_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_gather_nd_op_xpu.py index 0f9751cec4d9286a46df00b174f4dfb9e21d5076..68854edb0ebb6cb317d18faf2fcc46ff31d9d18b 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_gather_nd_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_gather_nd_op_xpu.py @@ -18,251 +18,140 @@ import unittest import numpy as np import sys sys.path.append("..") -from op_test import OpTest -from op_test_xpu import XPUOpTest -import paddle.fluid as fluid -import paddle - - -def gather_nd_grad(x, index): - dout_shape = index.shape[:-1] + x.shape[index.shape[-1]:] - numel = 1 - for i in dout_shape: - numel = numel * i - dout = np.full(dout_shape, 1. / numel) - dx = np.full_like(x, 0) - - index = tuple(index.reshape(-1, index.shape[-1]).T) - np.add.at(dx, index, dout) - - return dx - - -def test_class1(op_type, typename): - class TestGatherNdOpWithEmptyIndex(XPUOpTest): - #Index has empty element, which means copy entire tensor - - def setUp(self): - self.set_xpu() - self.place = paddle.XPUPlace(0) - self.op_type = "gather_nd" - xnp = np.random.random((5, 20)).astype(typename) - self.inputs = { - 'X': xnp, - 'Index': np.array([[], []]).astype("int32") - } - self.outputs = { - 'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :])) - } - - def set_xpu(self): - self.__class__.use_xpu = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - pass - - cls_name = "{0}_{1}_1".format(op_type, typename) - TestGatherNdOpWithEmptyIndex.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpWithEmptyIndex - - -def test_class2(op_type, typename): - class TestGatherNdOpWithIndex1(OpTest): - def setUp(self): - self.set_xpu() - self.place = paddle.XPUPlace(0) - self.op_type = "gather_nd" - xnp = np.random.random((5, 20)).astype(typename) - self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")} - self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} - def set_xpu(self): - self.__class__.use_xpu = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - pass - - cls_name = "{0}_{1}_2".format(op_type, typename) - TestGatherNdOpWithIndex1.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpWithIndex1 - - -def test_class3(op_type, typename): - class TestGatherNdOpWithLowIndex(OpTest): - #Index has low rank, X has high rank - - def setUp(self): - self.set_xpu() - self.place = paddle.XPUPlace(0) - self.op_type = "gather_nd" - xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) - index = np.array([[1], [2]]).astype("int64") - - self.inputs = {'X': xnp, 'Index': index} - self.outputs = {'Out': xnp[tuple(index.T)]} - self.x_grad = gather_nd_grad(xnp, index) - - def set_xpu(self): - self.__class__.use_xpu = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - pass +import paddle +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper - cls_name = "{0}_{1}_3".format(op_type, typename) - TestGatherNdOpWithLowIndex.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpWithLowIndex +paddle.enable_static() -def test_class4(op_type, typename): - class TestGatherNdOpIndex1(OpTest): - #Index has low rank, X has high rank +class XPUTestGatherNd(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'gather_nd' + class XPUTestGatherNdBase(XPUOpTest): def setUp(self): - self.set_xpu() - self.place = paddle.XPUPlace(0) self.op_type = "gather_nd" - xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) - index = np.array([1, 2]).astype("int64") - - self.inputs = {'X': xnp, 'Index': index} - - self.outputs = {'Out': xnp[tuple(index.T)]} - - def set_xpu(self): - self.__class__.use_xpu = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - pass - - cls_name = "{0}_{1}_4".format(op_type, typename) - TestGatherNdOpIndex1.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpIndex1 - - -def test_class5(op_type, typename): - class TestGatherNdOpWithSameIndexAsX(OpTest): - #Index has same rank as X's rank - - def setUp(self): - self.set_xpu() + self.dtype = self.in_type + self.__class__.no_need_check_grad = True self.place = paddle.XPUPlace(0) - self.op_type = "gather_nd" - xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) - index = np.array([[1, 1], [2, 1]]).astype("int64") - - self.inputs = {'X': xnp, 'Index': index} - self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22] + self.init_data() - def set_xpu(self): - self.__class__.use_xpu = True + self.inputs = {'X': self.xnp, 'Index': self.inp} + self.outputs = {'Out': self.output, } def test_check_output(self): self.check_output_with_place(self.place) - def test_check_grad(self): - pass - - cls_name = "{0}_{1}_5".format(op_type, typename) - TestGatherNdOpWithSameIndexAsX.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpWithSameIndexAsX - - -def test_class6(op_type, typename): - class TestGatherNdOpWithHighRankSame(OpTest): - #Both Index and X have high rank, and Rank(Index) = Rank(X) - - def setUp(self): - self.set_xpu() - self.place = paddle.XPUPlace(0) - self.op_type = "gather_nd" + def init_data(self): + self.xnp = np.random.random((5, 20)).astype(self.in_type) + self.inp = np.array([[], []]).astype("int32") + self.output = np.vstack( + (self.xnp[np.newaxis, :], self.xnp[np.newaxis, :])) + + class XPUTestGatherNdOpWithEmptyIndex1(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.random((5, 20)).astype(self.in_type) + self.inp = np.array([[], []]).astype("int32") + self.output = np.vstack( + (self.xnp[np.newaxis, :], self.xnp[np.newaxis, :])) + + class XPUTestGatherNdOpWithEmptyIndex2(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.random((5, 20)).astype(self.in_type) + self.inp = np.array([[], []]).astype("int64") + self.output = np.vstack( + (self.xnp[np.newaxis, :], self.xnp[np.newaxis, :])) + + class XPUTestGatherNdOpWithIndex1(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.random((5, 20)).astype(self.in_type) + self.inp = np.array([1]).astype("int32") + self.output = self.xnp[self.inp] + + class XPUTestGatherNdOpWithIndex2(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.random((5, 20)).astype(self.in_type) + self.inp = np.array([1]).astype("int64") + self.output = self.xnp[self.inp] + + class XPUTestGatherNdOpWithLowIndex1(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type) + self.inp = np.array([[1], [2]]).astype("int32") + self.output = self.xnp[tuple(self.inp.T)] + + class XPUTestGatherNdOpWithLowIndex2(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type) + self.inp = np.array([1, 2]).astype("int64") + self.output = self.xnp[tuple(self.inp.T)] + + class XPUTestGatherNdOpWithHighRankSame1(XPUTestGatherNdBase): + def init_data(self): shape = (5, 2, 3, 1, 10) - xnp = np.random.rand(*shape).astype(typename) - index = np.vstack([np.random.randint( - 0, s, size=2) for s in shape]).T - - self.inputs = {'X': xnp, 'Index': index.astype("int32")} - self.outputs = {'Out': xnp[tuple(index.T)]} - - def set_xpu(self): - self.__class__.use_xpu = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - pass - - cls_name = "{0}_{1}_6".format(op_type, typename) - TestGatherNdOpWithHighRankSame.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpWithHighRankSame - + self.xnp = np.random.rand(*shape).astype(self.in_type) + self.inp = np.vstack( + [np.random.randint( + 0, s, size=2) for s in shape]).T.astype("int32") + self.output = self.xnp[tuple(self.inp.T)] -def test_class7(op_type, typename): - class TestGatherNdOpWithHighRankDiff(OpTest): - #Both Index and X have high rank, Rank(Index) < Rank(X) + class XPUTestGatherNdOpWithHighRankSame2(XPUTestGatherNdBase): + def init_data(self): + shape = (5, 2, 3, 1, 10) + self.xnp = np.random.rand(*shape).astype(self.in_type) + self.inp = np.vstack( + [np.random.randint( + 0, s, size=2) for s in shape]).T.astype("int64") + self.output = self.xnp[tuple(self.inp.T)] - def setUp(self): - self.set_xpu() - self.place = paddle.XPUPlace(0) - self.op_type = "gather_nd" + class XPUTestGatherNdOpWithHighRankDiff1(XPUTestGatherNdBase): + def init_data(self): shape = (2, 3, 4, 1, 10) - xnp = np.random.rand(*shape).astype(typename) - index = np.vstack( + self.xnp = np.random.rand(*shape).astype(self.in_type) + self.inp = np.vstack( [np.random.randint( - 0, s, size=200) for s in shape]).T - index_re = index.reshape([20, 5, 2, 5]) - - self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} - self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} + 0, s, size=200) for s in shape]).T.astype("int32") + self.output = self.xnp[tuple(self.inp.T)] - def set_xpu(self): - self.__class__.use_xpu = True - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - pass - - cls_name = "{0}_{1}_7".format(op_type, typename) - TestGatherNdOpWithHighRankDiff.__name__ = cls_name - globals()[cls_name] = TestGatherNdOpWithHighRankDiff - - -class TestGatherNdAPI(unittest.TestCase): - def test_imperative(self): - paddle.disable_static() - input_1 = np.array([[1, 2], [3, 4], [5, 6]]) - index_1 = np.array([[1]]) - input = fluid.dygraph.to_variable(input_1) - index = fluid.dygraph.to_variable(index_1) - output = paddle.fluid.layers.gather(input, index) - output_np = output.numpy() - expected_output = np.array([3, 4]) - self.assertTrue(np.allclose(output_np, expected_output)) - paddle.enable_static() - - -for _typename in {'float32', 'int', 'int64'}: - test_class1('gather_nd', _typename) - test_class2('gather_nd', _typename) - test_class3('gather_nd', _typename) - test_class4('gather_nd', _typename) - test_class5('gather_nd', _typename) - test_class6('gather_nd', _typename) - test_class7('gather_nd', _typename) + class XPUTestGatherNdOpWithHighRankDiff2(XPUTestGatherNdBase): + def init_data(self): + shape = (2, 3, 4, 1, 10) + self.xnp = np.random.rand(*shape).astype(self.in_type) + self.inp = np.vstack( + [np.random.randint( + 0, s, size=200) for s in shape]).T.astype("int64") + self.output = self.xnp[tuple(self.inp.T)] + + class XPUTestGatherNdOpWithSameIndexAsX1(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type) + self.inp = np.array([[1, 1], [2, 1]]).astype("int32") + self.output = self.xnp[tuple(self.inp.T)] + + class XPUTestGatherNdOpWithSameIndexAsX2(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type) + self.inp = np.array([[1, 1], [2, 1]]).astype("int64") + self.output = self.xnp[tuple(self.inp.T)] + + class XPUTestGatherNdOpIndex1(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type) + self.inp = np.array([1, 2]).astype("int32") + self.output = self.xnp[tuple(self.inp.T)] + + class XPUTestGatherNdOpIndex2(XPUTestGatherNdBase): + def init_data(self): + self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type) + self.inp = np.array([1, 2]).astype("int64") + self.output = self.xnp[tuple(self.inp.T)] + + +support_types = get_xpu_op_support_types('gather_nd') +for stype in support_types: + create_test_class(globals(), XPUTestGatherNd, stype) if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_refactor_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_refactor_op_xpu.py index a1eb0af2bc978dd46b9f25b81f972669d7b93d94..e7ee89c567f420d303e81f0b2b8f922d75763c78 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_refactor_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_refactor_op_xpu.py @@ -69,7 +69,7 @@ class XPUTestArgsortOp1(XPUOpTestWrapper): self.descending = False if not hasattr( self, 'init_descending') else self.init_descending - if self.in_type == 'float32': + if self.in_type == np.float32: self.x = np.random.random(self.input_shape).astype(self.dtype) else: self.x = np.random.randint( @@ -118,7 +118,7 @@ class XPUTestArgsortOp2(XPUOpTestWrapper): self.init_axis() self.init_direction() - if self.in_type == 'float32': + if self.in_type == np.float32: self.x = np.random.random(self.input_shape).astype(self.dtype) else: self.x = np.random.randint( diff --git a/python/paddle/fluid/tests/unittests/xpu/test_scale_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_scale_op_xpu.py index 761e5c2243c65927c60c25a55a908829315e786d..b27eefb6a166f00ba3c6896b0da6b278d73b92f9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_scale_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_scale_op_xpu.py @@ -18,54 +18,78 @@ import unittest import numpy as np import sys sys.path.append("..") -from op_test_xpu import XPUOpTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.op import Operator -import paddle -from paddle.static import Program, program_guard - - -class TestXPUScaleOp(XPUOpTest): - def setUp(self): - self.op_type = "scale" - self.init_type() - self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} - self.attrs = {'scale': -2.3, 'use_xpu': True} - self.outputs = { - 'Out': self.inputs['X'] * self.dtype(self.attrs['scale']) - } - - def init_type(self): - self.dtype = np.float32 - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid import compiler, Program, program_guard -# class TestXPUScaleOpInt64(TestXPUScaleOp): -# def init_type(self): -# self.dtype = np.int64 - - -class TestScaleFp16Op(TestXPUScaleOp): - def init_dtype_type(self): - self.dtype = np.float16 - - def test_check_output(self): - place = core.XPUPlace(0) - self.check_output_with_place(place, atol=0.002) - - def test_check_grad(self): - place = core.XPUPlace(0) - self.check_grad_with_place(place, ["X"], "Out", max_relative_error=0.05) +import op_test +from op_test import OpTest, skip_check_grad_ci +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + + +class XPUTestScaleOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'scale' + self.use_dynamic_create_class = False + + class TestScaleOp(XPUOpTest): + def setUp(self): + self.init_dtype() + self.set_xpu() + self.op_type = "scale" + self.place = paddle.XPUPlace(0) + self.set_inputs() + self.set_attrs() + self.outputs = { + 'Out': self.inputs['X'] * self.dtype(self.attrs['scale']) + } + + def set_xpu(self): + self.__class__.use_xpu = True + self.__class__.no_need_check_grad = True + self.__class__.op_type = self.dtype + + def set_inputs(self): + self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} + + def init_dtype(self): + if "float16" == self.in_type: + self.dtype = np.float16 + if "float32" == self.in_type: + self.dtype = np.float32 + if "int64" == self.in_type: + self.dtype = np.int64 + + def set_attrs(self): + self.attrs = {'scale': -2.3} + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + class TestScaleOp1(TestScaleOp): + def set_attrs(self): + self.attrs = {'scale': 3.5} + + class TestScaleOp2(TestScaleOp): + def set_attrs(self): + self.attrs = {'scale': 6.77} + + class TestScaleOp3(TestScaleOp): + def set_attrs(self): + self.attrs = {'scale': -9.19} + + class TestScaleOp4(TestScaleOp): + def set_attrs(self): + self.attrs = {'scale': 0.0} + + class TestScaleOp5(TestScaleOp): + def set_attrs(self): + self.attrs = {'scale': -0.003} class TestScaleApiStatic(unittest.TestCase): @@ -108,5 +132,9 @@ class TestScaleInplaceApiDygraph(TestScaleApiDygraph): return x.scale_(scale, bias) +support_types = get_xpu_op_support_types('scale') +for stype in support_types: + create_test_class(globals(), XPUTestScaleOp, stype) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py index 9cb31d4270552d23435a6bebc71ae6ef208b204d..1aac42f2d63a117f63e20b151408ee4c91216465 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py @@ -19,251 +19,255 @@ import numpy as np import sys sys.path.append("..") from op_test_xpu import OpTest, XPUOpTest -from op_test import skip_check_grad_ci import paddle import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard from paddle.fluid.framework import convert_np_dtype_to_dtype_ +from paddle.fluid import compiler, Program, program_guard, core +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + from scipy.special import logit from scipy.special import expit paddle.enable_static() -class TestSigmoidCrossEntropyWithLogitsOp1(XPUOpTest): - """Test sigmoid_cross_entropy_with_logit_op with binary label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = 64 - num_classes = 20 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, (batch_size, num_classes)) - .astype(self.dtype)), - 'Label': np.random.randint(0, 2, (batch_size, num_classes)) - .astype(self.dtype) - } - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - self.outputs = {'Out': -term1 - term2} - - def test_check_output(self): - self.check_output_with_place(self.place) - - def test_check_grad(self): - self.check_grad_with_place(self.place, ['X'], 'Out') - - def set_xpu(self): - self.__class__.use_xpu = True - self.place = paddle.XPUPlace(0) - - def init_dtype(self): - self.dtype = np.float32 - - -class TestSigmoidCrossEntropyWithLogitsOp2( - TestSigmoidCrossEntropyWithLogitsOp1): - """Test sigmoid_cross_entropy_with_logit_op with probabalistic label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = 64 - num_classes = 20 - ignore_index = -1 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, (batch_size, num_classes)) - .astype(self.dtype)), - 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) - .astype(self.dtype) - } - self.attrs = {'ignore_index': ignore_index, } - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - out = -term1 - term2 - out[np.where(self.inputs['Label'] == ignore_index)] = 0 - self.outputs = {'Out': out} - - -class TestSigmoidCrossEntropyWithLogitsOp3( - TestSigmoidCrossEntropyWithLogitsOp1): - """Test sigmoid_cross_entropy_with_logit_op with probabalistic label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = 64 - num_classes = 20 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, (batch_size, num_classes)) - .astype(self.dtype)), - 'Label': np.random.uniform(0, 1, (batch_size, num_classes)) - .astype(self.dtype) - } - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - self.outputs = {'Out': -term1 - term2} - - -class TestSigmoidCrossEntropyWithLogitsOp4( - TestSigmoidCrossEntropyWithLogitsOp1): - """Test sigmoid_cross_entropy_with_logit_op with probabalistic label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = 64 - num_classes = 20 - ignore_index = -1 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, (batch_size, num_classes)) - .astype(self.dtype)), - 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) - .astype(self.dtype) - } - self.attrs = {'ignore_index': ignore_index, 'normalize': True} - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - out = -term1 - term2 - out[np.where(self.inputs['Label'] == ignore_index)] = 0 - if self.attrs['normalize']: - out = out / float( - np.where(self.inputs['Label'] != ignore_index)[0].size) - self.outputs = {'Out': out} - - -class TestSigmoidCrossEntropyWithLogitsOp5( - TestSigmoidCrossEntropyWithLogitsOp1): - """Test sigmoid_cross_entropy_with_logit_op with probabalistic label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = [10, 10] - num_classes = 20 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, tuple(batch_size + [num_classes])) - .astype(self.dtype)), - 'Label': np.random.uniform(0, 1, tuple(batch_size + [num_classes])) - .astype(self.dtype) - } - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - self.outputs = {'Out': -term1 - term2} - - -class TestSigmoidCrossEntropyWithLogitsNorm( - TestSigmoidCrossEntropyWithLogitsOp1): - """Test sigmoid_cross_entropy_with_logit_op with probabalistic label - """ - - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = [10, 10] - num_classes = 20 - ignore_index = -1 - self.inputs = { - 'X': logit( - np.random.uniform(0, 1, tuple(batch_size + [num_classes])) - .astype(self.dtype)), - 'Label': np.random.randint(-1, 2, tuple(batch_size + [num_classes])) - .astype(self.dtype) - } - self.attrs = {'ignore_index': ignore_index, 'normalize': True} - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - out = -term1 - term2 - out[np.where(self.inputs['Label'] == ignore_index)] = 0 - if self.attrs['normalize']: - out = out / float( - np.where(self.inputs['Label'] != ignore_index)[0].size) - self.outputs = {'Out': out} - - -class TestSigmoidCrossEntropyWithLogitsOp6( - TestSigmoidCrossEntropyWithLogitsOp1): +class XPUTestSigmoidCrossEntropyWithLogitsOp(XPUOpTestWrapper): """Test sigmoid_cross_entropy_with_logit_op with binary label """ - def setUp(self): - self.op_type = "sigmoid_cross_entropy_with_logits" - self.set_xpu() - self.init_dtype() - - batch_size = [10, 10] - num_classes = 20 - self.inputs = { - 'X': logit( + def __init__(self): + self.op_name = "sigmoid_cross_entropy_with_logits" + self.use_dynamic_create_class = False + + class TestSigmoidCrossEntropyWithLogitsOp(XPUOpTest): + def setUp(self): + self.set_xpu() + self.op_type = "sigmoid_cross_entropy_with_logits" + self.place = paddle.XPUPlace(0) + self.init_dtype() + self.set_inputs() + self.init_dtype() + self.set_output() + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def set_inputs(self): + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.randint(0, 2, (batch_size, num_classes)) + .astype(self.dtype) + } + self.attrs = {'num_classes': num_classes, 'batch_size': batch_size} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def set_xpu(self): + self.__class__.use_xpu = True + self.__class__.op_type = self.in_type + self.place = paddle.XPUPlace(0) + + def init_dtype(self): + self.dtype = self.in_type + + class TestSigmoidCrossEntropyWithLogitsOp2( + TestSigmoidCrossEntropyWithLogitsOp): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def set_inputs(self): + batch_size = 64 + num_classes = 20 + ignore_index = -1 + self.ignore_index = ignore_index + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) + .astype(self.dtype) + } + self.attrs = {'ignore_index': ignore_index} + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == self.ignore_index)] = 0 + self.outputs = {'Out': out} + + class TestSigmoidCrossEntropyWithLogitsOp3( + TestSigmoidCrossEntropyWithLogitsOp): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def set_inputs(self): + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype) + } + self.attrs = {'num_classes': num_classes, 'batch_size': batch_size} + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + class TestSigmoidCrossEntropyWithLogitsOp4( + TestSigmoidCrossEntropyWithLogitsOp): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def set_inputs(self): + batch_size = 64 + num_classes = 20 + ignore_index = -1 + self.ignore_index = ignore_index + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) + .astype(self.dtype) + } + self.attrs = {'ignore_index': ignore_index, 'normalize': True} + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == self.ignore_index)] = 0 + if self.attrs['normalize']: + out = out / float( + np.where(self.inputs['Label'] != self.ignore_index)[0].size) + self.outputs = {'Out': out} + + class TestSigmoidCrossEntropyWithLogitsOp5( + TestSigmoidCrossEntropyWithLogitsOp): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def set_inputs(self): + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype)), + 'Label': np.random.uniform(0, 1, tuple(batch_size + [num_classes])) - .astype(self.dtype)), - 'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes])) - .astype(self.dtype) - } - - # Fw Pass is implemented as elementwise sigmoid followed by - # elementwise logistic loss - # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) - sigmoid_X = expit(self.inputs['X']) - term1 = self.inputs['Label'] * np.log(sigmoid_X) - term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) - self.outputs = {'Out': -term1 - term2} - + .astype(self.dtype) + } + self.attrs = {'num_classes': num_classes, 'batch_size': batch_size} + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + class TestSigmoidCrossEntropyWithLogitsOp6( + TestSigmoidCrossEntropyWithLogitsOp): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def set_inputs(self): + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype)), + 'Label': + np.random.randint(0, 2, tuple(batch_size + [num_classes])) + .astype(self.dtype) + } + self.attrs = {'num_classes': num_classes, 'batch_size': batch_size} + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + class TestSigmoidCrossEntropyWithLogitsNorm( + TestSigmoidCrossEntropyWithLogitsOp): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def set_inputs(self): + batch_size = [10, 10] + num_classes = 20 + ignore_index = -1 + self.ignore_index = ignore_index + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype)), + 'Label': + np.random.randint(-1, 2, tuple(batch_size + [num_classes])) + .astype(self.dtype) + } + self.attrs = {'ignore_index': ignore_index, 'normalize': True} + + def set_output(self): + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == self.ignore_index)] = 0 + if self.attrs['normalize']: + out = out / float( + np.where(self.inputs['Label'] != self.ignore_index)[0].size) + self.outputs = {'Out': out} + + +support_types = get_xpu_op_support_types('sigmoid_cross_entropy_with_logits') +for stype in support_types: + create_test_class(globals(), XPUTestSigmoidCrossEntropyWithLogitsOp, stype) if __name__ == '__main__': unittest.main() diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 26da7ae2adfaceaffe90aa203ec78bd0edb14b61..d14cf11c8dd7eaea2482e7a043c76530fc6fc7d7 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1,5 +1,5 @@ - backward_api : matmul_grad - forward : matmul (const Tensor& x, const Tensor& y, bool transpose_x, bool transpose_y) -> Tensor(out) + forward : matmul (const Tensor& x, const Tensor& y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out) args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false) output : Tensor(x_grad), Tensor(y_grad) infer_meta : diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh index 09f22c33a84b5e845be29a585d1fafa117dde20e..eb58764f7fce3aa6a77126c84c54bf4c61b7487b 100644 --- a/tools/check_file_diff_approvals.sh +++ b/tools/check_file_diff_approvals.sh @@ -179,8 +179,8 @@ for API_FILE in ${API_FILES[*]}; do echo_line="You must have one RD (Xreki,luotao1,zhhsplendid) approval for ${API_FILE}, which manages the underlying code for PaddlePaddle.\n" check_approval 1 12538138 6836917 7913861 else - echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1) approval for ${API_FILE}, which manages the underlying code for fluid.\n" - check_approval 1 46782768 12538138 6836917 22561442 6888866 + echo_line="You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93) approval for ${API_FILE}, which manages the underlying code for fluid.\n" + check_approval 1 46782768 12538138 6836917 22561442 6888866 16605440 fi fi done @@ -288,8 +288,8 @@ fi HAS_OPERATORBASE_FLAG=`git diff -U0 --diff-filter=A upstream/$BRANCH | grep -E "public[[:space:]]+.*OperatorBase" || true` if [ "${HAS_OPERATORBASE_FLAG}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then - echo_line="In order to support dynamic graph, all ops are not recommended to inherit OperatorBase. Please use OperatorWithKernel instead.\nYou must have one RD (phlrain (Recommend), luotao1, lanxianghit or XiaoguangHu01) approval for the inherit of OperatorBase.\nYou inherit the OperatorBase class. The corresponding lines are as follows:\n${HAS_OPERATORBASE_FLAG}" - check_approval 1 43953930 6836917 47554610 46782768 + echo_line="In order to support dynamic graph, all ops are not recommended to inherit OperatorBase. Please use OperatorWithKernel instead.\nYou must have one RD (phlrain (Recommend), luotao1, lanxianghit, XiaoguangHu01, or qili93) approval for the inherit of OperatorBase.\nYou inherit the OperatorBase class. The corresponding lines are as follows:\n${HAS_OPERATORBASE_FLAG}" + check_approval 1 43953930 6836917 47554610 46782768 16605440 fi HAS_INPLACE_TESTS=`git diff -U0 upstream/$BRANCH |grep "+" |grep -E "inplace_atol[[:space:]]*=.*" || true`