diff --git a/paddle/fluid/eager/api/utils/hook_utils.cc b/paddle/fluid/eager/api/utils/hook_utils.cc
index c7927716300528fdfa571de720ce12e7246b5f1d..9abd7be49d44cbab4b3482961df461dd7164328f 100644
--- a/paddle/fluid/eager/api/utils/hook_utils.cc
+++ b/paddle/fluid/eager/api/utils/hook_utils.cc
@@ -52,49 +52,44 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
   }
 }
 
-static void RetainGradForRegularNode(
-    const paddle::experimental::Tensor& tensor) {
-  AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
-  if (meta->RetainGrads()) {
+void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
+  if (IsLeafTensor(tensor)) {
+    // Leaf tensor's grad will always be retained
+    // Refer to implementation of AccumulationNode for more details
     return;
   } else {
-    meta->SetRetainGrads(true);
-  }
+    AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
+    if (meta->RetainGrads()) {
+      return;
+    } else {
+      meta->SetRetainGrads(true);
+    }
 
-  std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
-      meta->WeakGrad();
+    std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
+        meta->WeakGrad();
 
-  // Define Hook
-  auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
-    if (!weak_grad_tensor.expired()) {
-      auto grad_tensor = weak_grad_tensor.lock();
-      if (t.defined()) {
-        VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
-        // Simply Copy impl() to grad_tensor
-        grad_tensor->set_impl(t.impl());
-        return *grad_tensor.get();
+    // Define Hook
+    auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
+      if (!weak_grad_tensor.expired()) {
+        auto grad_tensor = weak_grad_tensor.lock();
+        if (t.defined()) {
+          VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
+          // Simply Copy impl() to grad_tensor
+          grad_tensor->set_impl(t.impl());
+          return *grad_tensor.get();
+        } else {
+          VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
+          return paddle::experimental::Tensor();
+        }
       } else {
         VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
         return paddle::experimental::Tensor();
       }
-    } else {
-      VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
-      return paddle::experimental::Tensor();
-    }
-  };
+    };
 
-  // Append to GradientHooks
-  RegisterGradientHookForTensor(tensor,
-                                std::make_shared(hook));
-}
-
-void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
-  if (IsLeafTensor(tensor)) {
-    // Leaf tensor's grad will always be retained
-    // Refer to implementation of AccumulationNode for more details
-    return;
-  } else {
-    RetainGradForRegularNode(tensor);
+    // Append to GradientHooks
+    RegisterGradientHookForTensor(tensor,
+                                  std::make_shared(hook));
   }
 }
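Reviewer note (not part of the patch): the retained-grad hook above captures the grad tensor through a std::weak_ptr, so the hook never extends the grad tensor's lifetime and never forms an ownership cycle. A minimal standalone sketch of that capture pattern, with illustrative types only (not Paddle API):

    #include <iostream>
    #include <memory>

    int main() {
      auto grad = std::make_shared<int>(0);
      std::weak_ptr<int> weak_grad = grad;     // non-owning capture
      auto hook = [weak_grad](int incoming) {
        if (auto locked = weak_grad.lock()) {  // grad still alive -> update it
          *locked = incoming;
          return *locked;
        }
        return 0;                              // grad already destroyed
      };
      std::cout << hook(42) << "\n";           // prints 42
      grad.reset();
      std::cout << hook(7) << "\n";            // prints 0
    }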
diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index a8e0ed7a41a043e12332ad347f673a6c27e5f1ec..102fad56373803a19f07afc7dda72e9704ac83d5 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1156,11 +1156,13 @@ static std::string GenerateGradNodeCreationContent(
       grad_node_creation_str += paddle::string::Sprintf(
           SET_OUT_RANK_TEMPLATE, output_autograd_name, output_position);
 
-      const char* SET_HISTORY_TEMPLATE =
-          " egr::EagerUtils::SetHistory(&%s, grad_node);\n";
-      grad_node_creation_str +=
-          paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
-
+      // Intermediate Tensor does not require SetHistory
+      if (!output.intermediate()) {
+        const char* SET_HISTORY_TEMPLATE =
+            " egr::EagerUtils::SetHistory(&%s, grad_node);\n";
+        grad_node_creation_str +=
+            paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      }
       const char* SET_GRAD_IN_META_TEMPLATE =
           " grad_node->SetGradInMeta(&%s, %d);\n";
       grad_node_creation_str += paddle::string::Sprintf(
@@ -1173,17 +1175,20 @@ static std::string GenerateGradNodeCreationContent(
       grad_node_creation_str += paddle::string::Sprintf(
           SET_OUT_RANK_TEMPLATE, output_autograd_name, output_position);
 
-      const char* SET_HISTORY_TEMPLATE =
-          " egr::EagerUtils::SetHistory(%s, grad_node);\n";
-      grad_node_creation_str +=
-          paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
-
+      // Intermediate Tensor does not require SetHistory
+      if (!output.intermediate()) {
+        const char* SET_HISTORY_TEMPLATE =
+            " egr::EagerUtils::SetHistory(%s, grad_node);\n";
+        grad_node_creation_str +=
+            paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      }
       const char* SET_GRAD_IN_META_TEMPLATE =
           " grad_node->SetGradInMeta(%s, %d);\n";
       grad_node_creation_str += paddle::string::Sprintf(
           SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
     }
 
+    // Intermediate Tensor does not require CheckAndRetainGrad
     if (!output.intermediate()) {
       VLOG(6) << "Generated Call RetainGradForTensor";
       const char* RETAIN_GRAD_TEMPLATE =
diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index c6e56e34627a52bc19df7e8d87371811fcec8697..02183e2ca5ce9f0996017eb7df59ee716b0f1ae2 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -24,6 +24,17 @@ core_ops_args_info = {}
 core_ops_args_type_info = {}
 
 
+yaml_types_mapping = {
+    'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t',  'size_t' : 'size_t', \
+    'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
+    'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \
+    'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
+    'Tensor' : 'Tensor',
+    'Tensor[]' : 'std::vector<Tensor>',
+    'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>'
+}
+
+
 def ParseArguments():
     parser = argparse.ArgumentParser(
         description='Eager Code Generator Args Parser')
@@ -59,7 +70,9 @@ def IsPlainTensorType(string):
 
 
 def IsVectorTensorType(string):
-    vector_tensor_types = ['list(Tensor)']
+    vector_tensor_types = [
+        'std::vector<std::vector<Tensor>>', 'std::vector<Tensor>'
+    ]
     if string in vector_tensor_types:
         return True
     return False
@@ -180,6 +193,9 @@ def ParseYamlArgs(string):
         arg_name = m.group(3).split("=")[0].strip()
         default_value = m.group(3).split("=")[1].strip() if len(
             m.group(3).split("=")) > 1 else None
+
+        assert arg_type in yaml_types_mapping.keys()
+        arg_type = yaml_types_mapping[arg_type]
 
         if "Tensor" in arg_type:
             assert default_value is None
             inputs_list.append([arg_name, arg_type, i])
@@ -219,6 +235,10 @@ def ParseYamlReturnsWithName(string):
         m = re.search(pattern, ret)
         ret_type = m.group(1)
         ret_name = m.group(2)
+
+        assert ret_type in yaml_types_mapping.keys()
+        ret_type = yaml_types_mapping[ret_type]
+
         assert "Tensor" in ret_type
         returns_list.append([ret_name, ret_type, i])
diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc
index 7073ca8f0527ba8237da734db0c8724baa2a49ec..356fdcaf054277085be57491eb1525beeac8d792 100644
--- a/paddle/fluid/eager/backward.cc
+++ b/paddle/fluid/eager/backward.cc
@@ -221,10 +221,11 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
               << " 's name is: " << grad_output_tensor.name();
 
       auto* next_node = next_node_shared.get();
-      if (!node_input_buffers_dict.count(next_node)) {
-        node_input_buffers_dict[next_node] =
-            std::make_unique<GradTensorHolder>(next_node->InputMeta());
+      const auto& input_meta = next_node->InputMeta();
+      auto grad_tensor_holder =
+          std::make_unique<GradTensorHolder>(input_meta);
+      node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
       }
 
       VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
               << ", rank: " << edge_rank.second;
diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc
index 35416281f188892ec11413a19abad9b3e5c29e76..b1189106b8f871ab618972ad93e9812ce443e55d 100644
--- a/paddle/fluid/eager/grad_node_info.cc
+++ b/paddle/fluid/eager/grad_node_info.cc
@@ -244,7 +244,7 @@ GradNodeBase::ApplyGradientHooks(
       if (!out.defined() || !out.initialized()) {
         out = (*hook)(tensors[slot_id][rank]);
       } else {
-        // If more than one hook is registered, the input to the next hook func 
+        // If more than one hook is registered, the input to the next hook func
         // should be the output of the previous hook
         out = (*hook)(out);
       }
diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc
index a7e5931f1f9bc66006fb1a37836be1eda371953e..39861c80522a920502fff91177256a4b7abf6dc6 100644
--- a/paddle/fluid/eager/utils.cc
+++ b/paddle/fluid/eager/utils.cc
@@ -122,12 +122,21 @@ paddle::experimental::Tensor* EagerUtils::mutable_grad(
 void EagerUtils::SetHistory(std::vector<AutogradMeta*>* autograd_metas,
                             const std::shared_ptr<GradNodeBase>& grad_node) {
   for (const auto& autograd_meta : *autograd_metas) {
+    if (dynamic_cast<GradNodeAccumulation*>(autograd_meta->GradNode())) {
+      VLOG(6) << "Warning: Reseting GradNodeAccumulation for leaf tensor is "
+                 "detected";
+    }
     autograd_meta->SetGradNode(grad_node);
   }
 }
 
 void EagerUtils::SetHistory(AutogradMeta* autograd_meta,
                             const std::shared_ptr<GradNodeBase>& grad_node) {
+  if (dynamic_cast<GradNodeAccumulation*>(autograd_meta->GradNode())) {
+    VLOG(6)
+        << "Warning: Reseting GradNodeAccumulation for leaf tensor is detected";
+  }
+
   autograd_meta->SetGradNode(grad_node);
 }
diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index d9287b9a624d39c40cd63071ab08257a8526ce17..57fb68e80427afa56372bebb31ff5822135858b6 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -88,6 +88,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
     return var_types[0] == proto::VarType::SELECTED_ROWS;
   }
 
+  bool IsForInferShape() const override { return true; }
+
  private:
   const InferShapeContext& ctx_;
 };
@@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor {
       }
     } else {
       auto* var = BOOST_GET_CONST(VarDesc*, var_);
-      return phi::make_ddim(var->GetShape());
+
+      return var->GetShape().empty() ? phi::make_ddim({0UL})
+                                     : phi::make_ddim(var->GetShape());
     }
   }
 
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 16718a316513e3574e9a7eb14ed50106c8b0dcb6..e33d4feb82a9e7a92c3dabea0ccc5fe370afda66 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -489,6 +489,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
     return ctx_.OutputVar(name)->IsType<phi::SelectedRows>();
   }
 
+  bool IsForInferShape() const override { return false; }
+
  private:
   const ExecutionContext& ctx_;
 };
diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc
index 93bc2c02d57cb7b57cf48d6f5c34a27a97637377..14997dd9610138e32a45ef17abc9276cd1dad172 100644
--- a/paddle/fluid/framework/phi_utils.cc
+++ b/paddle/fluid/framework/phi_utils.cc
@@ -125,6 +125,15 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
     return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                           kernel_key.dtype());
   }
+#endif
+#ifdef PADDLE_WITH_IPU
+  if (platform::is_ipu_place(expected_kernel_key.place_)) {
+    VLOG(3) << "pten missing IPU kernel: " << op.Type()
+            << ", expected_kernel_key:" << expected_kernel_key
+            << ", fallbacking to CPU one!";
+    return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
+                          kernel_key.dtype());
+  }
 #endif
   return phi::KernelKey();
 }
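Reviewer note (not part of the patch): this hunk extends the existing fallback rule to IPU places — when no phi kernel exists for the expected place, a CPU kernel key with the same layout and dtype is returned. A hypothetical standalone sketch of that rule (illustrative names, not the Paddle API):

    #include <iostream>
    #include <string>

    struct KernelKey {
      std::string backend, layout, dtype;
    };

    // Keep layout/dtype, retarget only the backend to CPU.
    KernelKey FallBackToCpu(const KernelKey& expected) {
      return {"CPU", expected.layout, expected.dtype};
    }

    int main() {
      KernelKey ipu_key{"IPU", "NCHW", "float32"};
      KernelKey cpu_key = FallBackToCpu(ipu_key);
      std::cout << cpu_key.backend << " " << cpu_key.layout << " "
                << cpu_key.dtype << "\n";  // CPU NCHW float32
    }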
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 37214534f3c937bcf62bb34b51da2c934c566ced..0281fd917658ad0a2f6b22cefe02efec97870721 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -453,6 +453,23 @@ if(WITH_MKLDNN)
     download_int8_data_without_verify(${INT8_GOOGLENET_MODEL_DIR} "GoogleNet_int8_model.tar.gz" )
     inference_analysis_api_int8_test_run_custom_warmup_batch_size(test_analyzer_int8_googlenet ${INT8_IMG_CLASS_TEST_APP} ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} 10)
 
+    # mobilenetv3_large_x1_0 int8
+    set(INT8_MOBILENETV3_LARGE_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv3_large")
+    set(INT8_MOBILENETV3_FILE_NAME "MobileNetV3_large_x1_0_infer.tar")
+    if (NOT EXISTS ${INT8_MOBILENETV3_LARGE_MODEL_DIR}/${INT8_MOBILENETV3_FILE_NAME})
+        inference_download_and_uncompress_without_verify(${INT8_MOBILENETV3_LARGE_MODEL_DIR} "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/" ${INT8_MOBILENETV3_FILE_NAME})
+    endif()
+    inference_analysis_test_run(test_analyzer_int8_mobilenetv3_large
+            COMMAND ${INT8_IMG_CLASS_TEST_APP}
+            ARGS --infer_model=${INT8_MOBILENETV3_LARGE_MODEL_DIR}/MobileNetV3_large_x1_0_infer
+                 --infer_data=${IMAGENET_DATA_PATH}
+                 --warmup_batch_size=50
+                 --batch_size=1
+                 --enable_int8=true
+                 --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
+                 --iterations=100
+                 --with_accuracy_layer=false)
+
     ### BFLOAT16 tests
 
     # build test binary to be used in subsequent tests
@@ -472,6 +489,17 @@ if(WITH_MKLDNN)
     # mobilenetv2 bfloat16
     inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_mobilenetv2 ${BF16_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH})
 
+    # mobilenetv3_large
+    inference_analysis_test_run(test_analyzer_bfloat16_mobilenetv3_large
+            COMMAND ${BF16_IMG_CLASS_TEST_APP}
+            ARGS --infer_model=${INT8_MOBILENETV3_LARGE_MODEL_DIR}/MobileNetV3_large_x1_0_infer
+                 --infer_data=${IMAGENET_DATA_PATH}
+                 --batch_size=1
+                 --enable_bf16=true
+                 --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+                 --iterations=100
+                 --with_accuracy_layer=false)
+
     ### Object detection models
     set(PASCALVOC_DATA_PATH "${INT8_DATA_DIR}/pascalvoc_val_head_300.bin")
     set(INT8_OBJ_DETECT_TEST_APP "test_analyzer_int8_object_detection")
@@ -739,6 +767,7 @@ if(WITH_MKLDNN)
         set_tests_properties(test_analyzer_quant_performance_benchmark PROPERTIES TIMEOUT 120)
         set_tests_properties(test_analyzer_int8_mobilenetv2 PROPERTIES TIMEOUT 120)
         set_tests_properties(test_analyzer_int8_mobilenetv1 PROPERTIES TIMEOUT 120)
+        set_tests_properties(test_analyzer_int8_mobilenetv3_large PROPERTIES TIMEOUT 120)
     endif()
 
     set_tests_properties(lite_resnet50_test PROPERTIES TIMEOUT 120)
diff --git a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc
index 3b16b0d34fd4cb87879bb6ed585e72b48167ac2c..f267f0f28d685e51f0359a345c52fbbe4a49fa16 100644
--- a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc
@@ -14,13 +14,19 @@ limitations under the License. */
 
 #include "paddle/fluid/inference/api/paddle_analysis_config.h"
 #include "paddle/fluid/inference/tests/api/tester_helper.h"
+#include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
 namespace inference {
 namespace analysis {
 
 void SetConfig(AnalysisConfig *cfg) {
-  cfg->SetModel(FLAGS_infer_model);
+  std::ifstream model_file(FLAGS_infer_model + "/__model__");
+  if (model_file.good())
+    cfg->SetModel(FLAGS_infer_model);
+  else
+    cfg->SetModel(FLAGS_infer_model + "/inference.pdmodel",
+                  FLAGS_infer_model + "/inference.pdiparams");
   cfg->DisableGpu();
   cfg->SwitchIrOptim();
   cfg->SwitchSpecifyInputNames();
@@ -38,7 +44,12 @@ TEST(Analyzer_bfloat16_image_classification, bfloat16) {
   // read data from file and prepare batches with test data
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInputs(&input_slots_all);
-  b_cfg.EnableMkldnnBfloat16();
+  if (FLAGS_enable_bf16 &&
+      platform::MayIUse(platform::cpu_isa_t::avx512_bf16)) {
+    b_cfg.EnableMkldnnBfloat16();
+  } else {
+    FLAGS_enable_bf16 = false;
+  }
   CompareBFloat16AndAnalysis(&cfg, &b_cfg, input_slots_all);
 }
 
diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
index 8f8b73044232a5cacfa3609e5f8e32ccf375d418..b07163b518b529e7ab01107e1f0d217443f574bd 100644
--- a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
@@ -22,7 +22,12 @@ namespace inference {
 namespace analysis {
 
 void SetConfig(AnalysisConfig *cfg) {
-  cfg->SetModel(FLAGS_infer_model);
+  std::ifstream model_file(FLAGS_infer_model + "/__model__");
+  if (model_file.good())
+    cfg->SetModel(FLAGS_infer_model);
+  else
+    cfg->SetModel(FLAGS_infer_model + "/inference.pdmodel",
+                  FLAGS_infer_model + "/inference.pdiparams");
   cfg->DisableGpu();
   cfg->SwitchIrOptim();
   cfg->SwitchSpecifyInputNames();
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 637fa16e31ba7996713a6971c3a1802627811e7f..e63dfd14175b9955fbf5b6fdb0fb7904a330f264 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -213,15 +213,15 @@ std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
                 element_in_batch * 3 * 224 * 224, 3 * 224 * 224,
                 static_cast<float *>(images.data.data()) + i * 3 * 224 * 224);
 
-    std::copy_n(static_cast<int64_t *>(test_data[batch][1].data.data()) +
-                    element_in_batch,
-                1, static_cast<int64_t *>(labels.data.data()) + i);
+    if (FLAGS_with_accuracy_layer)
+      std::copy_n(static_cast<int64_t *>(test_data[batch][1].data.data()) +
+                      element_in_batch,
+                  1, static_cast<int64_t *>(labels.data.data()) + i);
   }
-
-  auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(2);
+  auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(
+      FLAGS_with_accuracy_layer ? 2 : 1);
   (*warmup_data)[0] = std::move(images);
-  (*warmup_data)[1] = std::move(labels);
+  if (FLAGS_with_accuracy_layer) (*warmup_data)[1] = std::move(labels);
   return warmup_data;
 }
 
@@ -254,9 +254,13 @@ void SetInputs(std::vector<std::vector<PaddleTensor>> *inputs,
   }
   for (auto i = 0; i < iterations; i++) {
     auto images = image_reader.NextBatch();
-    auto labels = label_reader.NextBatch();
-    inputs->emplace_back(
-        std::vector<PaddleTensor>{std::move(images), std::move(labels)});
+    std::vector<PaddleTensor> tmp_vec;
+    tmp_vec.push_back(std::move(images));
+    if (FLAGS_with_accuracy_layer) {
+      auto labels = label_reader.NextBatch();
+      tmp_vec.push_back(std::move(labels));
+    }
+    inputs->push_back(std::move(tmp_vec));
   }
 }
 
@@ -825,7 +829,8 @@ void CompareQuantizedAndAnalysis(
   SummarizePerformance("FP32", sample_latency_fp32, "INT8",
                        sample_latency_int8);
 
-  CompareAccuracy(quantized_outputs, analysis_outputs, compared_idx);
+  if (FLAGS_with_accuracy_layer)
+    CompareAccuracy(quantized_outputs, analysis_outputs, compared_idx);
 }
 
 void CompareBFloat16AndAnalysis(
@@ -864,7 +869,8 @@ void CompareBFloat16AndAnalysis(
   SummarizePerformance("FP32", sample_latency_fp32, "BF16",
                        sample_latency_bf16);
 
-  CompareAccuracy(bf16_outputs, analysis_outputs, compared_idx);
+  if (FLAGS_with_accuracy_layer)
+    CompareAccuracy(bf16_outputs, analysis_outputs, compared_idx);
 }
 
 void CompareAnalysisAndAnalysis(
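Reviewer note (not part of the patch): the tester changes above size the warmup bundle by whether an accuracy layer (and hence a labels tensor) is present. A standalone sketch of that sizing logic, with illustrative types only:

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      bool with_accuracy_layer = false;  // mirrors FLAGS_with_accuracy_layer
      // one slot for images, plus an optional slot for labels
      std::vector<std::string> warmup_data(with_accuracy_layer ? 2 : 1);
      warmup_data[0] = "images";
      if (with_accuracy_layer) warmup_data[1] = "labels";
      std::cout << warmup_data.size() << "\n";  // 1
    }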
diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
index e80df5f95bb4ab33a6c08cc646d0ef8311e38936..6157a3a925de51a9b65efbb2df9d5178132b1baf 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
@@ -18,6 +18,10 @@
 #include
 #include
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
 namespace paddle {
 namespace operators {
 
@@ -92,9 +96,13 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
 };
 
+DELCARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor,
+                            PT_INFER_META(phi::MeanRawInferMeta));
+
 REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
                   ops::ReduceMeanOpGradMaker,
-                  ops::ReduceMeanOpGradMaker);
+                  ops::ReduceMeanOpGradMaker,
+                  ReduceMeanInferShapeFunctor);
 REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
                   ops::ReduceMeanDoubleGradDescMaker,
                   ops::ReduceMeanDoubleGradOpBaseMaker,
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index bdab14a18a05ab3e0df1dbda57f3753033cfacb4..8ef0712dc7a757dfe91e48e7b0bb32f24840e02e 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -16,6 +16,10 @@
 
 #include
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
 namespace paddle {
 namespace framework {
 class OpDesc;
@@ -98,10 +102,14 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_sum"; }
 };
 
+DELCARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor,
+                            PT_INFER_META(phi::ReduceInferMetaBase));
+
 REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
                   ops::ReduceSumVarTypeInference,
                   ops::ReduceSumOpGradMaker,
-                  ops::ReduceSumOpGradMaker);
+                  ops::ReduceSumOpGradMaker,
+                  ReduceSumInferShapeFunctor);
 REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
                   ops::ReduceSumDoubleOpGradMaker,
                   ops::ReduceSumDoubleOpGradMaker,
diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc
index 0adf61d7ce3e5b5792b9dc65d5ac8f884dc81ea5..88ef1f3ea4aa4d8d827a810026575c20e596b4e7 100644
--- a/paddle/fluid/operators/selu_op.cc
+++ b/paddle/fluid/operators/selu_op.cc
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/selu_op.h"
-
 #include
 #include
 #include
@@ -127,9 +125,3 @@ REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType,
                   ops::SeluGradMaker,
                   ops::SeluGradMaker);
 REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);
-REGISTER_OP_CPU_KERNEL(
-    selu, ops::SeluKernel,
-    ops::SeluKernel);
-REGISTER_OP_CPU_KERNEL(
-    selu_grad, ops::SeluGradKernel,
-    ops::SeluGradKernel);
diff --git a/paddle/fluid/operators/selu_op.cu b/paddle/fluid/operators/selu_op.cu
deleted file mode 100644
index fb3245ab7609ea9067709134a3713e9871dbb4d4..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/selu_op.cu
+++ /dev/null
@@ -1,22 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#include "paddle/fluid/operators/selu_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    selu, ops::SeluKernel,
-    ops::SeluKernel);
-REGISTER_OP_CUDA_KERNEL(
-    selu_grad, ops::SeluGradKernel,
-    ops::SeluGradKernel);
"map", + RegisterGetter(map_options_getter, options_type, "engine_options", "map", [&]() { return popart_options.engineOptions; }); - RegisterSetter(container_options, "reportOptions", + RegisterSetter(container_options, "report_options", [&](const std::pair& p) { popart_options.reportOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "reportOptions", "map", + RegisterGetter(map_options_getter, options_type, "report_options", "map", [&]() { return popart_options.reportOptions; }); - RegisterSetter(container_options, "convolutionOptions", + RegisterSetter(container_options, "convolution_options", [&](const std::pair& p) { popart_options.convolutionOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "convolutionOptions", "map", + RegisterGetter(map_options_getter, options_type, "convolution_options", "map", [&]() { return popart_options.convolutionOptions; }); - RegisterSetter(container_options, "lstmOptions", + RegisterSetter(container_options, "lstm_options", [&](const std::pair& p) { popart_options.lstmOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "lstmOptions", "map", + RegisterGetter(map_options_getter, options_type, "lstm_options", "map", [&]() { return popart_options.lstmOptions; }); - RegisterSetter(container_options, "gclOptions", + RegisterSetter(container_options, "gcl_options", [&](const std::pair& p) { popart_options.gclOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "gclOptions", "map", + RegisterGetter(map_options_getter, options_type, "gcl_options", "map", [&]() { return popart_options.gclOptions; }); } @@ -415,21 +445,21 @@ void IpuStrategy::SetTensorLocation(const std::string& tensor, "Unknown tensor location: %s", tensor)); } - if (opt == "minElementsForOffChip") { + if (opt == "min_elements_for_off_chip") { settings->minElementsForOffChip = value; - } else if (opt == "minElementsForReplicatedTensorSharding") { + } else if (opt == "min_elements_for_replicated_tensor_sharding") { settings->minElementsForReplicatedTensorSharding = value; - } else if (opt == "onChip") { + } else if (opt == "on_chip") { settings->location.storage = value > 0 ? popart::TensorStorage::OnChip : popart::TensorStorage::OffChip; - } else if (opt == "useReplicatedTensorSharding") { + } else if (opt == "use_replicated_tensor_sharding") { settings->location.replicatedTensorSharding = value > 0 ? popart::ReplicatedTensorSharding::On : popart::ReplicatedTensorSharding::Off; - } else if (opt == "useIOTilesToLoad") { + } else if (opt == "use_io_tiles_to_load") { settings->location.loadTileSet = value > 0 ? popart::TileSet::IO : popart::TileSet::Compute; - } else if (opt == "useIOTilesToStore") { + } else if (opt == "use_io_tiles_to_store") { settings->location.storageTileSet = value > 0 ? 
diff --git a/paddle/fluid/platform/device/ipu/ipu_strategy.cc b/paddle/fluid/platform/device/ipu/ipu_strategy.cc
index 943dfcc6cffb875fc3cebfc88e35adeaba47fd63..e806b0b30e4e03759847cc2e1838171020a064b1 100644
--- a/paddle/fluid/platform/device/ipu/ipu_strategy.cc
+++ b/paddle/fluid/platform/device/ipu/ipu_strategy.cc
@@ -120,121 +120,151 @@ IpuStrategy::IpuStrategy() {
   RegisterGetter(options_getter, options_type, #name, "string",              \
                  [&]() { return popart_options.aliased_name; })
 
-#define ADD_POPART_ENUM_OPTION(name, EnumType) \
-  ADD_POPART_ENUM_OPTION_ALIAS(name, name, EnumType)
-
-#define ADD_POPART_BOOL_OPTION(name) ADD_POPART_BOOL_OPTION_ALIAS(name, name)
-
-#define ADD_POPART_UINT64_OPTION(name) \
-  ADD_POPART_UINT64_OPTION_ALIAS(name, name)
-
-#define ADD_POPART_DOUBLE_OPTION(name) \
-  ADD_POPART_DOUBLE_OPTION_ALIAS(name, name)
-
-#define ADD_POPART_STRING_OPTION(name) \
-  ADD_POPART_STRING_OPTION_ALIAS(name, name)
-
-  ADD_POPART_ENUM_OPTION(autodiffSettings.stitchStrategy,
-                         AutodiffStitchStrategy);
-  ADD_POPART_ENUM_OPTION(batchSerializationSettings.transformContext,
-                         BatchSerializationTransformContext);
-  ADD_POPART_ENUM_OPTION(batchSerializationSettings.method,
-                         BatchSerializationMethod);
-  ADD_POPART_ENUM_OPTION(batchSerializationSettings.batchSchedule,
-                         BatchSerializationBatchSchedule);
-  ADD_POPART_ENUM_OPTION(autoRecomputation, RecomputationType);
-  ADD_POPART_ENUM_OPTION(mergeVarUpdate, MergeVarUpdateType);
-  ADD_POPART_ENUM_OPTION(virtualGraphMode, VirtualGraphMode);
-  ADD_POPART_ENUM_OPTION(syntheticDataMode, SyntheticDataMode);
-  ADD_POPART_ENUM_OPTION(subgraphCopyingStrategy, SubgraphCopyingStrategy);
-  ADD_POPART_ENUM_OPTION(accumulationAndReplicationReductionType,
-                         ReductionType);
-  ADD_POPART_ENUM_OPTION(meanAccumulationAndReplicationReductionStrategy,
-                         MeanReductionStrategy);
-
-  ADD_POPART_STRING_OPTION(logDir);
-  ADD_POPART_STRING_OPTION(cachePath);
-  ADD_POPART_STRING_OPTION(partialsTypeMatMuls);
-  ADD_POPART_STRING_OPTION(customCodeletCompileFlags);
-  ADD_POPART_STRING_OPTION(serializedPoprithmsShiftGraphsDir);
-  ADD_POPART_STRING_OPTION(kahnTieBreaker);
-
-  ADD_POPART_UINT64_OPTION(executionPhaseSettings.phases);
-  ADD_POPART_UINT64_OPTION(executionPhaseSettings.stages);
-  ADD_POPART_UINT64_OPTION(batchSerializationSettings.factor);
-  ADD_POPART_UINT64_OPTION(firstDotOp);
-  ADD_POPART_UINT64_OPTION(finalDotOp);
-  ADD_POPART_UINT64_OPTION(numIOTiles);
-  ADD_POPART_UINT64_OPTION(mergeVarUpdateMemThreshold);
-  ADD_POPART_UINT64_OPTION(looseThresholdAtPeak);
-  ADD_POPART_UINT64_OPTION(accumulationFactor);
-  ADD_POPART_UINT64_OPTION(swapLimitScheduler);
-  ADD_POPART_UINT64_OPTION(globalReplicationFactor);
-  ADD_POPART_UINT64_OPTION(globalReplicaOffset);
-  ADD_POPART_UINT64_OPTION(defaultPrefetchBufferingDepth);
-  ADD_POPART_UINT64_OPTION(compilationProgressTotal);
-  ADD_POPART_UINT64_OPTION(transitiveClosureOptimizationThreshold);
-
-  ADD_POPART_BOOL_OPTION(batchSerializationSettings.concatOnVirtualGraphChange);
-  ADD_POPART_BOOL_OPTION(
+  ADD_POPART_ENUM_OPTION_ALIAS(autodiff_settings.stitch_strategy,
+                               autodiffSettings.stitchStrategy,
+                               AutodiffStitchStrategy);
+  ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.transform_context,
+                               batchSerializationSettings.transformContext,
+                               BatchSerializationTransformContext);
+  ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.method,
+                               batchSerializationSettings.method,
+                               BatchSerializationMethod);
+  ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.batch_schedule,
+                               batchSerializationSettings.batchSchedule,
+                               BatchSerializationBatchSchedule);
+  ADD_POPART_ENUM_OPTION_ALIAS(auto_recomputation, autoRecomputation,
+                               RecomputationType);
+  ADD_POPART_ENUM_OPTION_ALIAS(merge_var_update, mergeVarUpdate,
+                               MergeVarUpdateType);
+  ADD_POPART_ENUM_OPTION_ALIAS(virtual_graph_mode, virtualGraphMode,
+                               VirtualGraphMode);
+  ADD_POPART_ENUM_OPTION_ALIAS(synthetic_data_mode, syntheticDataMode,
+                               SyntheticDataMode);
+  ADD_POPART_ENUM_OPTION_ALIAS(subgraph_copying_strategy,
+                               subgraphCopyingStrategy,
+                               SubgraphCopyingStrategy);
+  ADD_POPART_ENUM_OPTION_ALIAS(accumulation_and_replication_reduction_type,
+                               accumulationAndReplicationReductionType,
+                               ReductionType);
+  ADD_POPART_ENUM_OPTION_ALIAS(
+      mean_accumulation_and_replication_reduction_strategy,
+      meanAccumulationAndReplicationReductionStrategy, MeanReductionStrategy);
+
+  ADD_POPART_STRING_OPTION_ALIAS(log_dir, logDir);
+  ADD_POPART_STRING_OPTION_ALIAS(cache_path, cachePath);
+  ADD_POPART_STRING_OPTION_ALIAS(partials_type_matmuls, partialsTypeMatMuls);
+  ADD_POPART_STRING_OPTION_ALIAS(custom_codelet_compile_flags,
+                                 customCodeletCompileFlags);
+  ADD_POPART_STRING_OPTION_ALIAS(serialized_poprithms_shift_graphs_dir,
+                                 serializedPoprithmsShiftGraphsDir);
+  ADD_POPART_STRING_OPTION_ALIAS(kahn_tie_breaker, kahnTieBreaker);
+
+  ADD_POPART_UINT64_OPTION_ALIAS(execution_phase_settings.phases,
+                                 executionPhaseSettings.phases);
+  ADD_POPART_UINT64_OPTION_ALIAS(execution_phase_settings.stages,
+                                 executionPhaseSettings.stages);
+  ADD_POPART_UINT64_OPTION_ALIAS(batch_serialization_settings.factor,
+                                 batchSerializationSettings.factor);
+  ADD_POPART_UINT64_OPTION_ALIAS(first_dot_op, firstDotOp);
+  ADD_POPART_UINT64_OPTION_ALIAS(final_dot_op, finalDotOp);
+  ADD_POPART_UINT64_OPTION_ALIAS(num_io_tiles, numIOTiles);
+  ADD_POPART_UINT64_OPTION_ALIAS(merge_var_update_mem_threshold,
+                                 mergeVarUpdateMemThreshold);
+  ADD_POPART_UINT64_OPTION_ALIAS(loose_threshold_at_peak, looseThresholdAtPeak);
+  ADD_POPART_UINT64_OPTION_ALIAS(accumulation_factor, accumulationFactor);
+  ADD_POPART_UINT64_OPTION_ALIAS(swap_limit_scheduler, swapLimitScheduler);
+  ADD_POPART_UINT64_OPTION_ALIAS(global_replication_factor,
+                                 globalReplicationFactor);
+  ADD_POPART_UINT64_OPTION_ALIAS(global_replica_offset, globalReplicaOffset);
+  ADD_POPART_UINT64_OPTION_ALIAS(default_prefetch_buffering_depth,
+                                 defaultPrefetchBufferingDepth);
+  ADD_POPART_UINT64_OPTION_ALIAS(compilation_progress_total,
+                                 compilationProgressTotal);
+  ADD_POPART_UINT64_OPTION_ALIAS(transitive_closure_optimization_threshold,
+                                 transitiveClosureOptimizationThreshold);
+  ADD_POPART_BOOL_OPTION_ALIAS(
+      batch_serialization_settings.concat_on_virtual_graph_change,
+      batchSerializationSettings.concatOnVirtualGraphChange);
+  ADD_POPART_BOOL_OPTION_ALIAS(
+      batch_serialization_settings.concat_on_execution_phase_change,
       batchSerializationSettings.concatOnExecutionPhaseChange);
-  ADD_POPART_BOOL_OPTION(
+  ADD_POPART_BOOL_OPTION_ALIAS(
+      batch_serialization_settings.concat_on_pipeline_stage_change,
       batchSerializationSettings.concatOnPipelineStageChange);
-  ADD_POPART_BOOL_OPTION(strictOpVersions);
-  ADD_POPART_BOOL_OPTION(opxAliasChecking);
-  ADD_POPART_BOOL_OPTION(opxModifyChecking);
-  ADD_POPART_BOOL_OPTION(dotOpNames);
-  ADD_POPART_BOOL_OPTION(exportPoplarComputationGraph);
-  ADD_POPART_BOOL_OPTION(exportPoplarVertexGraph);
-  ADD_POPART_BOOL_OPTION(separateCallOpPdfs);
-  ADD_POPART_BOOL_OPTION(enableOutlining);
-  ADD_POPART_BOOL_OPTION(enableOutliningCopyCostPruning);
-  ADD_POPART_BOOL_OPTION(rearrangeAnchorsOnHost);
-  ADD_POPART_BOOL_OPTION(enablePrefetchDatastreams);
-  ADD_POPART_BOOL_OPTION(enableNonStableSoftmax);
-  ADD_POPART_BOOL_OPTION(enableReplicatedGraphs);
-  ADD_POPART_BOOL_OPTION(enableGradientAccumulation);
-  ADD_POPART_BOOL_OPTION(instrumentWithHardwareCycleCounter);
-  ADD_POPART_BOOL_OPTION(enablePipelining);
+  ADD_POPART_BOOL_OPTION_ALIAS(strict_op_versions, strictOpVersions);
+  ADD_POPART_BOOL_OPTION_ALIAS(opx_alias_checking, opxAliasChecking);
+  ADD_POPART_BOOL_OPTION_ALIAS(opx_modify_checking, opxModifyChecking);
+  ADD_POPART_BOOL_OPTION_ALIAS(dot_op_names, dotOpNames);
+  ADD_POPART_BOOL_OPTION_ALIAS(export_poplar_computation_graph,
+                               exportPoplarComputationGraph);
+  ADD_POPART_BOOL_OPTION_ALIAS(export_poplar_vertex_graph,
+                               exportPoplarVertexGraph);
+  ADD_POPART_BOOL_OPTION_ALIAS(separate_call_op_pdfs, separateCallOpPdfs);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_outlining, enableOutlining);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_outlining_copy_cost_pruning,
+                               enableOutliningCopyCostPruning);
+  ADD_POPART_BOOL_OPTION_ALIAS(rearrange_anchors_on_host,
+                               rearrangeAnchorsOnHost);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_prefetch_datastreams,
+                               enablePrefetchDatastreams);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_non_stable_softmax,
+                               enableNonStableSoftmax);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_replicated_graphs,
+                               enableReplicatedGraphs);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_gradient_accumulation,
+                               enableGradientAccumulation);
+  ADD_POPART_BOOL_OPTION_ALIAS(instrument_with_hardware_cycle_counter,
+                               instrumentWithHardwareCycleCounter);
   ADD_POPART_BOOL_OPTION_ALIAS(enable_pipelining, enablePipelining);
-  ADD_POPART_BOOL_OPTION(disableGradAccumulationTensorStreams);
-  ADD_POPART_BOOL_OPTION(compileEngine);
-  ADD_POPART_BOOL_OPTION(constantWeights);
-  ADD_POPART_BOOL_OPTION(enableEngineCaching);
-  ADD_POPART_BOOL_OPTION(enableMergeExchange);
-  ADD_POPART_BOOL_OPTION(enableFloatingPointChecks);
-  ADD_POPART_BOOL_OPTION(enableStochasticRounding);
+  ADD_POPART_BOOL_OPTION_ALIAS(disable_grad_accumulation_tensor_streams,
+                               disableGradAccumulationTensorStreams);
+  ADD_POPART_BOOL_OPTION_ALIAS(compile_engine, compileEngine);
+  ADD_POPART_BOOL_OPTION_ALIAS(constant_weights, constantWeights);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_engine_caching, enableEngineCaching);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_merge_exchange, enableMergeExchange);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_floating_point_checks,
+                               enableFloatingPointChecks);
   ADD_POPART_BOOL_OPTION_ALIAS(enable_stochastic_rounding,
                                enableStochasticRounding);
-  ADD_POPART_BOOL_OPTION(explicitRecomputation);
-  ADD_POPART_BOOL_OPTION(enableExplicitMainLoops);
-  ADD_POPART_BOOL_OPTION(useHostCopyOps);
-  ADD_POPART_BOOL_OPTION(aliasZeroCopy);
-  ADD_POPART_BOOL_OPTION(delayVarUpdates);
-  ADD_POPART_BOOL_OPTION(enableFullyConnectedPass);
-  ADD_POPART_BOOL_OPTION(enableSerializedMatmuls);
-  ADD_POPART_BOOL_OPTION(enableStableNorm);
-  ADD_POPART_BOOL_OPTION(decomposeGradSum);
-  ADD_POPART_BOOL_OPTION(enableDistributedReplicatedGraphs);
-  ADD_POPART_BOOL_OPTION(groupHostSync);
-  ADD_POPART_BOOL_OPTION(automaticLossScalingSettings.enabled);
-  ADD_POPART_BOOL_OPTION(instrumentWithHardwareCycleCounter);
-  ADD_POPART_BOOL_OPTION(enableSupportedDataTypeCasting);
-  ADD_POPART_BOOL_OPTION(groupNormStridedChannelGrouping);
-  ADD_POPART_BOOL_OPTION(scheduleNonWeightUpdateGradientConsumersEarly);
-
-  ADD_POPART_DOUBLE_OPTION(outlineSequenceBreakCost);
-  ADD_POPART_DOUBLE_OPTION(outlineThreshold);
-  ADD_POPART_DOUBLE_OPTION(timeLimitScheduler);
-  ADD_POPART_DOUBLE_OPTION(automaticLossScalingSettings.binEdgeLocation);
-  ADD_POPART_DOUBLE_OPTION(
+  ADD_POPART_BOOL_OPTION_ALIAS(explicit_recomputation, explicitRecomputation);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_explicit_main_loops,
+                               enableExplicitMainLoops);
+  ADD_POPART_BOOL_OPTION_ALIAS(use_host_copy_ops, useHostCopyOps);
+  ADD_POPART_BOOL_OPTION_ALIAS(alias_zero_copy, aliasZeroCopy);
+  ADD_POPART_BOOL_OPTION_ALIAS(delay_var_updates, delayVarUpdates);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_fully_connected_pass,
+                               enableFullyConnectedPass);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_serialized_matmuls,
+                               enableSerializedMatmuls);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_stable_norm, enableStableNorm);
+  ADD_POPART_BOOL_OPTION_ALIAS(decompose_grad_sum, decomposeGradSum);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_distributed_replicated_graphs,
+                               enableDistributedReplicatedGraphs);
+  ADD_POPART_BOOL_OPTION_ALIAS(group_host_sync, groupHostSync);
+  ADD_POPART_BOOL_OPTION_ALIAS(automatic_loss_scaling_settings.enabled,
+                               automaticLossScalingSettings.enabled);
+  ADD_POPART_BOOL_OPTION_ALIAS(instrument_with_hardware_cycle_counter,
+                               instrumentWithHardwareCycleCounter);
+  ADD_POPART_BOOL_OPTION_ALIAS(enable_supported_data_type_casting,
+                               enableSupportedDataTypeCasting);
+  ADD_POPART_BOOL_OPTION_ALIAS(group_norm_strided_channel_grouping,
+                               groupNormStridedChannelGrouping);
+  ADD_POPART_BOOL_OPTION_ALIAS(
+      schedule_non_weight_update_gradient_consumers_early,
+      scheduleNonWeightUpdateGradientConsumersEarly);
+
+  ADD_POPART_DOUBLE_OPTION_ALIAS(outline_sequence_break_cost,
+                                 outlineSequenceBreakCost);
+  ADD_POPART_DOUBLE_OPTION_ALIAS(outline_threshold, outlineThreshold);
+  ADD_POPART_DOUBLE_OPTION_ALIAS(time_limit_scheduler, timeLimitScheduler);
+  ADD_POPART_DOUBLE_OPTION_ALIAS(
+      automatic_loss_scaling_settings.bin_edge_location,
+      automaticLossScalingSettings.binEdgeLocation);
+  ADD_POPART_DOUBLE_OPTION_ALIAS(
+      automatic_loss_scaling_settings.threshold_upper_count_proportion,
       automaticLossScalingSettings.thresholdUpperCountProportion);
 
-#undef ADD_POPART_STRING_OPTION
-#undef ADD_POPART_DOUBLE_OPTION
-#undef ADD_POPART_UINT64_OPTION
-#undef ADD_POPART_BOOL_OPTION
-#undef ADD_POPART_ENUM_OPTION
 #undef ADD_POPART_STRING_OPTION_ALIAS
 #undef ADD_POPART_DOUBLE_OPTION_ALIAS
 #undef ADD_POPART_UINT64_OPTION_ALIAS
@@ -278,14 +308,14 @@ IpuStrategy::IpuStrategy() {
       });
 
   RegisterSetter(
-      container_options, "dotChecks",
+      container_options, "dot_checks",
      [&](const std::pair& p) {
         std::uint64_t value = std::stoul(p.first);
         popart_options.dotChecks.insert(static_cast(value));
       });
 
   RegisterGetter(
-      vector_options_getter, options_type, "dotChecks", "vector", [&]() {
+      vector_options_getter, options_type, "dot_checks", "vector", [&]() {
        std::vector res;
        for (auto x : popart_options.dotChecks) {
          res.push_back(std::to_string(static_cast(x)));
@@ -293,7 +323,7 @@ IpuStrategy::IpuStrategy() {
         return res;
       });
 
-  RegisterSetter(container_options, "hardwareInstrumentations",
+  RegisterSetter(container_options, "hardware_instrumentations",
                  [&](const std::pair& p) {
                    std::uint64_t value = std::stoul(p.first);
                    popart_options.hardwareInstrumentations.insert(
@@ -301,8 +331,8 @@ IpuStrategy::IpuStrategy() {
                  });
 
   RegisterGetter(
-      vector_options_getter, options_type, "hardwareInstrumentations", "vector",
-      [&]() {
+      vector_options_getter, options_type, "hardware_instrumentations",
+      "vector", [&]() {
        std::vector res;
        for (auto x : popart_options.hardwareInstrumentations) {
          res.push_back(std::to_string(static_cast(x)));
@@ -310,12 +340,12 @@ IpuStrategy::IpuStrategy() {
         return res;
       });
 
-  RegisterSetter(container_options, "customCodelets",
+  RegisterSetter(container_options, "custom_codelets",
                  [&](const std::pair& p) {
                    popart_options.customCodelets.push_back(p.first);
                  });
 
-  RegisterGetter(vector_options_getter, options_type, "customCodelets",
+  RegisterGetter(vector_options_getter, options_type, "custom_codelets",
                  "vector", [&]() {
                    std::vector res;
                    for (auto x : popart_options.customCodelets) {
@@ -324,44 +354,44 @@ IpuStrategy::IpuStrategy() {
                    return res;
                  });
 
-  RegisterSetter(container_options, "engineOptions",
+  RegisterSetter(container_options, "engine_options",
                  [&](const std::pair& p) {
                    popart_options.engineOptions.emplace(p);
                  });
 
-  RegisterGetter(map_options_getter, options_type, "engineOptions", "map",
+  RegisterGetter(map_options_getter, options_type, "engine_options", "map",
                  [&]() { return popart_options.engineOptions; });
 
-  RegisterSetter(container_options, "reportOptions",
+  RegisterSetter(container_options, "report_options",
                  [&](const std::pair& p) {
                    popart_options.reportOptions.emplace(p);
                  });
 
-  RegisterGetter(map_options_getter, options_type, "reportOptions", "map",
+  RegisterGetter(map_options_getter, options_type, "report_options", "map",
                  [&]() { return popart_options.reportOptions; });
 
-  RegisterSetter(container_options, "convolutionOptions",
+  RegisterSetter(container_options, "convolution_options",
                  [&](const std::pair& p) {
                    popart_options.convolutionOptions.emplace(p);
                  });
 
-  RegisterGetter(map_options_getter, options_type, "convolutionOptions", "map",
+  RegisterGetter(map_options_getter, options_type, "convolution_options", "map",
                  [&]() { return popart_options.convolutionOptions; });
 
-  RegisterSetter(container_options, "lstmOptions",
+  RegisterSetter(container_options, "lstm_options",
                  [&](const std::pair& p) {
                    popart_options.lstmOptions.emplace(p);
                  });
 
-  RegisterGetter(map_options_getter, options_type, "lstmOptions", "map",
+  RegisterGetter(map_options_getter, options_type, "lstm_options", "map",
                  [&]() { return popart_options.lstmOptions; });
 
-  RegisterSetter(container_options, "gclOptions",
+  RegisterSetter(container_options, "gcl_options",
                  [&](const std::pair& p) {
                    popart_options.gclOptions.emplace(p);
                  });
 
-  RegisterGetter(map_options_getter, options_type, "gclOptions", "map",
+  RegisterGetter(map_options_getter, options_type, "gcl_options", "map",
                  [&]() { return popart_options.gclOptions; });
 }
@@ -415,21 +445,21 @@ void IpuStrategy::SetTensorLocation(const std::string& tensor,
         "Unknown tensor location: %s", tensor));
   }
 
-  if (opt == "minElementsForOffChip") {
+  if (opt == "min_elements_for_off_chip") {
     settings->minElementsForOffChip = value;
-  } else if (opt == "minElementsForReplicatedTensorSharding") {
+  } else if (opt == "min_elements_for_replicated_tensor_sharding") {
     settings->minElementsForReplicatedTensorSharding = value;
-  } else if (opt == "onChip") {
+  } else if (opt == "on_chip") {
     settings->location.storage = value > 0 ? popart::TensorStorage::OnChip
                                            : popart::TensorStorage::OffChip;
-  } else if (opt == "useReplicatedTensorSharding") {
+  } else if (opt == "use_replicated_tensor_sharding") {
     settings->location.replicatedTensorSharding =
         value > 0 ? popart::ReplicatedTensorSharding::On
                   : popart::ReplicatedTensorSharding::Off;
-  } else if (opt == "useIOTilesToLoad") {
+  } else if (opt == "use_io_tiles_to_load") {
     settings->location.loadTileSet =
         value > 0 ? popart::TileSet::IO : popart::TileSet::Compute;
-  } else if (opt == "useIOTilesToStore") {
+  } else if (opt == "use_io_tiles_to_store") {
     settings->location.storageTileSet =
         value > 0 ? popart::TileSet::IO : popart::TileSet::Compute;
   } else {
@@ -464,6 +494,20 @@ std::string IpuStrategy::GetOptionType(const std::string& option) {
   return options_type[option];
 }
 
+std::vector<std::string> IpuStrategy::GetAllOptionNames() {
+  std::vector<std::string> names;
+  for (auto& option : options_getter) {
+    names.push_back(option.first);
+  }
+  for (auto& option : vector_options_getter) {
+    names.push_back(option.first);
+  }
+  for (auto& option : map_options_getter) {
+    names.push_back(option.first);
+  }
+  return names;
+}
+
 void IpuStrategy::EnablePattern(const std::string& t) {
   VLOG(10) << "enable popart pattern: " << t;
   popart_patterns.enablePattern(t, true);
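Reviewer note (not part of the patch): the IpuStrategy changes re-register every option under a snake_case name that forwards to the camelCase popart field. A minimal standalone sketch of that alias scheme (illustrative names, not the Paddle API):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    struct PopartOptions { bool enablePipelining = false; };

    int main() {
      PopartOptions popart_options;
      std::map<std::string, std::function<void(bool)>> setters;
      // snake_case user-facing name -> camelCase popart field
      setters["enable_pipelining"] = [&](bool v) {
        popart_options.enablePipelining = v;
      };
      setters.at("enable_pipelining")(true);
      std::cout << popart_options.enablePipelining << "\n";  // 1
    }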
diff --git a/paddle/fluid/platform/device/ipu/ipu_strategy.h b/paddle/fluid/platform/device/ipu/ipu_strategy.h
index 64436dc14fec3393b0a2a4473ad436d7d08f5217..571fb1e163718388a779e128fb6aaf76659d7183 100644
--- a/paddle/fluid/platform/device/ipu/ipu_strategy.h
+++ b/paddle/fluid/platform/device/ipu/ipu_strategy.h
@@ -24,7 +24,8 @@ namespace paddle {
 namespace platform {
 namespace ipu {
 
-struct IpuStrategy {
+class IpuStrategy {
+ public:
   IpuStrategy();
 
   // TODO(alleng) create PaddleOptions
@@ -75,22 +76,30 @@ struct IpuStrategy {
   // custom ops
   std::vector custom_ops;
 
- private:
-  std::map> bool_options;
-  std::map> uint64_options;
-  std::map> double_options;
-  std::map> string_options;
-  std::map)>>
-      container_options;
+ public:
+  void AddBoolOption(const std::string &option, bool value);
+  void AddUint64Option(const std::string &option, std::uint64_t value);
+  void AddDoubleOption(const std::string &option, double value);
+  void AddStringOption(const std::string &option, const std::string &value);
+  void InsertStringOption(const std::string &option, const std::string &value);
+  void InsertStringPairOption(const std::string &option, const std::string &key,
+                              const std::string &value);
+  void SetTensorLocation(const std::string &tensor, const std::string &option,
+                         std::uint64_t value);
+  void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
+                   const std::string &domain, int version);
 
-  std::map> options_getter;
-  std::map()>>
-      vector_options_getter;
-  std::map()>>
-      map_options_getter;
-  std::map options_type;
+  std::string GetOption(const std::string &);
+  std::vector GetVectorOption(const std::string &);
+  std::map GetMapOption(const std::string &);
+  std::string GetOptionType(const std::string &);
+  std::vector GetAllOptionNames();
+
+  void EnablePattern(const std::string &t);
+  void DisablePattern(const std::string &t);
+  const bool IsPatternEnabled(const std::string &t);
 
+ private:
   template
   void set(
       const std::string &key, ValueType value,
@@ -117,27 +126,20 @@ struct IpuStrategy {
     return it->second();
   }
 
- public:
-  void AddBoolOption(const std::string &option, bool value);
-  void AddUint64Option(const std::string &option, std::uint64_t value);
-  void AddDoubleOption(const std::string &option, double value);
-  void AddStringOption(const std::string &option, const std::string &value);
-  void InsertStringOption(const std::string &option, const std::string &value);
-  void InsertStringPairOption(const std::string &option, const std::string &key,
-                              const std::string &value);
-  void SetTensorLocation(const std::string &tensor, const std::string &option,
-                         std::uint64_t value);
-  void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
-                   const std::string &domain, int version);
-
-  std::string GetOption(const std::string &);
-  std::vector GetVectorOption(const std::string &);
-  std::map GetMapOption(const std::string &);
-  std::string GetOptionType(const std::string &);
+  std::map> bool_options;
+  std::map> uint64_options;
+  std::map> double_options;
+  std::map> string_options;
+  std::map)>>
+      container_options;
 
-  void EnablePattern(const std::string &t);
-  void DisablePattern(const std::string &t);
-  const bool IsPatternEnabled(const std::string &t);
+  std::map> options_getter;
+  std::map()>>
+      vector_options_getter;
+  std::map()>>
+      map_options_getter;
+  std::map options_type;
 };
 
 }  // namespace ipu
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 6e553ad2e60e292881fa8bb0294ea2a247656b67..3d8815e2eb61b53a6c8447fc8ce09a9c113963f2 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -3919,6 +3919,8 @@ All parameter, weight, gradient are variables in Paddle.
              }
              return res;
            })
+      .def("get_all_option_names",
+           &platform::ipu::IpuStrategy::GetAllOptionNames)
       .def("enable_pattern", &platform::ipu::IpuStrategy::EnablePattern)
       .def("disable_pattern", &platform::ipu::IpuStrategy::DisablePattern)
       .def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled);
diff --git a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h
index 843b19d217feb332a278c80378aaeb856442de9a..ca8a22a7e75d33de6e9f510aea5aab8c24255c36 100644
--- a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h
+++ b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h
@@ -46,6 +46,8 @@ class ProtoArgumentMappingContext : public phi::ArgumentMappingContext {
   bool IsDenseTensorOutput(const std::string& name) const override;
   bool IsSelectedRowsOutput(const std::string& name) const override;
 
+  bool IsForInferShape() const override { return false; }
+
  private:
   mlir::Operation* op_;
   const std::unordered_map& input_map_;
diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h
index af29b3bab5c3cc4b2e1caeb4eee9689179464d01..f625d57df2ef2dc2f9505853dc5e07e5d9e0022e 100644
--- a/paddle/phi/core/compat/arg_map_context.h
+++ b/paddle/phi/core/compat/arg_map_context.h
@@ -91,6 +91,10 @@ class ArgumentMappingContext {
 
   virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
   virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
+
+  // use this function to mark it comes from InferShapeArgumentMappingContext
+  // and will be used in infershape
+  virtual bool IsForInferShape() const = 0;
 };
 
 }  // namespace phi
+
+#include "paddle/phi/kernels/selu_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    selu_grad, GPU, ALL_LAYOUT, phi::SeluGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/selu_kernel.cu b/paddle/phi/kernels/gpu/selu_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..99303d8c18a97893a939a8f358bac02603fae329
--- /dev/null
+++ b/paddle/phi/kernels/gpu/selu_kernel.cu
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/selu_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/selu_kernel_impl.h"
+
+PD_REGISTER_KERNEL(selu, GPU, ALL_LAYOUT, phi::SeluKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/selu_grad_kernel_impl.h b/paddle/phi/kernels/impl/selu_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..d09c87b0a4ed268dc6e17815a3cc0072c54e382f
--- /dev/null
+++ b/paddle/phi/kernels/impl/selu_grad_kernel_impl.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/phi/kernels/impl/selu_kernel_impl.h"
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+template <typename T, typename Context>
+void SeluGradKernel(const Context& dev_ctx,
+                    const DenseTensor& out,
+                    const DenseTensor& dout,
+                    float scale,
+                    float alpha,
+                    DenseTensor* dx) {
+  auto dx_ptr = dev_ctx.template Alloc<T>(dx);
+  SeluGradFunctor<T> functor(
+      out.data<T>(), dout.data<T>(), alpha, scale, dx_ptr);
+  size_t limit = static_cast<size_t>(out.numel());
+  paddle::platform::ForRange<Context> for_range(dev_ctx, limit);
+  for_range(functor);
+}
+}  // namespace phi
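A quick host-side sketch (not part of the patch) of the math SeluGradFunctor encodes: for y > 0 the local gradient is scale, and for y <= 0 it is y + alpha * scale, since scale * alpha * e^x = y + alpha * scale when y = scale * alpha * (e^x - 1). The SELU constants below are the usual published defaults, an assumption, not values taken from this diff.

    // Standalone check: grad-from-output rule vs. a numerical derivative.
    #include <cmath>
    #include <cstdio>

    int main() {
      const float alpha = 1.6732632423543772f;  // assumed default SELU alpha
      const float scale = 1.0507009873554805f;  // assumed default SELU scale
      auto selu = [&](float x) {
        return scale * (x > 0.f ? x : alpha * (std::exp(x) - 1.f));
      };
      for (float x : {-2.f, -0.5f, 0.5f, 2.f}) {
        float y = selu(x);
        // Same rule as SeluGradFunctor: tmp = scale for y > 0, y + la otherwise.
        float analytic = y > 0.f ? scale : y + alpha * scale;
        float eps = 1e-3f;
        float numeric = (selu(x + eps) - selu(x - eps)) / (2.f * eps);
        std::printf("x=%+.2f analytic=%.5f numeric=%.5f\n", x, analytic, numeric);
      }
      return 0;
    }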
diff --git a/paddle/phi/kernels/impl/selu_kernel_impl.h b/paddle/phi/kernels/impl/selu_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..888bac42bfd91c99af2599ecef038c6f5a5424c1
--- /dev/null
+++ b/paddle/phi/kernels/impl/selu_kernel_impl.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+
+#include "paddle/fluid/operators/math.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T>
+struct SeluFunctor {
+  SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr)
+      : x_data_ptr_(x_data_ptr),
+        alpha_(alpha),
+        scale_(scale),
+        y_data_ptr_(y_data_ptr) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    T x_ele = x_data_ptr_[idx];
+    if (x_ele <= 0) {
+      x_ele = alpha_ * paddle::operators::real_exp(x_ele) - alpha_;
+    }
+    y_data_ptr_[idx] = scale_ * x_ele;
+  }
+  const T* x_data_ptr_;
+  const float alpha_;
+  const float scale_;
+  T* y_data_ptr_;
+};
+
+template <typename T>
+struct SeluGradFunctor {
+  SeluGradFunctor(const T* y_data_ptr,
+                  const T* dy_data_ptr,
+                  float alpha,
+                  float scale,
+                  T* dx_data_ptr)
+      : y_data_ptr_(y_data_ptr),
+        dy_data_ptr_(dy_data_ptr),
+        alpha_(alpha),
+        scale_(scale),
+        la_(alpha * scale),
+        dx_data_ptr_(dx_data_ptr) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    T y_ele = y_data_ptr_[idx];
+    T dy_ele = dy_data_ptr_[idx];
+
+    float tmp = scale_;
+    if (y_ele <= 0) {
+      tmp = y_ele + la_;
+    }
+    dx_data_ptr_[idx] = dy_ele * tmp;
+  }
+  const T* y_data_ptr_;
+  const T* dy_data_ptr_;
+  const float alpha_;
+  const float scale_;
+  const float la_;
+  T* dx_data_ptr_;
+};
+
+template <typename T, typename Context>
+void SeluKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                float scale,
+                float alpha,
+                DenseTensor* out) {
+  auto out_ptr = dev_ctx.template Alloc<T>(out);
+  SeluFunctor<T> functor(x.data<T>(), alpha, scale, out_ptr);
+  size_t limit = static_cast<size_t>(x.numel());
+  paddle::platform::ForRange<Context> for_range(dev_ctx, limit);
+  for_range(functor);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/math_kernel.h b/paddle/phi/kernels/math_kernel.h
index c6036f4a0421b87860dcf9301c77f689fab2c952..342393d79bd4d3729afdae45b423b50827a37d61 100644
--- a/paddle/phi/kernels/math_kernel.h
+++ b/paddle/phi/kernels/math_kernel.h
@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
                  bool keep_dim) {
   auto dense_out = phi::Empty<T, Context>(dev_ctx);
   MetaTensor meta_out(&dense_out);
-  ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
+  ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out);
   MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
   return dense_out;
 }
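The literal `false` passed here is the new reduce_all flag. A minimal host-side model (a sketch, not the phi code) of how ReduceInferMetaBase now resolves it: the caller's flag is OR-ed with whether the normalized axis set covers every input dimension.

    #include <cstdint>
    #include <cstdio>
    #include <set>
    #include <vector>

    // Hypothetical stand-in for the reduce_all resolution in ReduceInferMetaBase.
    bool ResolveReduceAll(bool reduce_all, std::vector<int64_t> axis, int64_t rank) {
      for (auto& a : axis)
        if (a < 0) a += rank;  // normalize negative axes, as formated_axis does
      std::set<int64_t> dims(axis.begin(), axis.end());
      bool full_dim = true;
      for (int64_t i = 0; i < rank; ++i) {
        if (dims.count(i) == 0) {
          full_dim = false;
          break;
        }
      }
      return reduce_all || full_dim;
    }

    int main() {
      std::printf("%d\n", ResolveReduceAll(false, {0, -1}, 2));  // 1: axes cover all dims
      std::printf("%d\n", ResolveReduceAll(false, {1}, 3));      // 0
      std::printf("%d\n", ResolveReduceAll(true, {1}, 3));       // 1: explicit flag wins
      return 0;
    }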
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index 4f3c069f3b249161eb83698c4ded150b8f003b14..19427551fb3f0e55f1ad4302bdb687bdf54e920d 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -136,7 +136,9 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
   return shared_memory[threadIdx.x];
 }

-// Swap data
+/**
+ * @brief Swap data.
+ */
 template <typename T>
 __device__ __forceinline__ void Swap(T* first_value, T* second_value) {
   T t_value;
@@ -145,7 +147,9 @@ __device__ __forceinline__ void Swap(T* first_value, T* second_value) {
   (*second_value) = t_value;
 }

-// swap with monotonic_type
+/**
+ * @brief Swap data according to monotonic_type.
+ */
 template <typename T>
 __device__ __forceinline__ void Comparator(T* first_value,
                                            T* second_value,
@@ -155,6 +159,9 @@ __device__ __forceinline__ void Comparator(T* first_value,
   }
 }

+/**
+ * @brief Swap data and data index according to monotonic_type.
+ */
 template <typename T, typename IndexType>
 __device__ __forceinline__ void ComparatorWithIndex(T* first_value,
@@ -170,6 +177,18 @@ __device__ __forceinline__ void ComparatorWithIndex(T* first_value,
   }
 }

+/**
+ * @brief Get the largest power of two that is less than or equal to n.
+ */
+__device__ inline int GetLastPow2(int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+  return std::max(1, n - (n >> 1));
+}
+
 }  // namespace details
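GetLastPow2 rounds n down to a power of two: the shift-or cascade smears the most significant bit into every lower bit, and n - (n >> 1) then keeps only that top bit. A host-side copy for eyeballing the behavior (not part of the patch):

    #include <algorithm>
    #include <cstdio>

    // Host copy of details::GetLastPow2.
    int GetLastPow2(int n) {
      n |= (n >> 1);
      n |= (n >> 2);
      n |= (n >> 4);
      n |= (n >> 8);
      n |= (n >> 16);
      return std::max(1, n - (n >> 1));
    }

    int main() {
      for (int n : {1, 2, 3, 5, 8, 100, 1023, 1024}) {
        std::printf("GetLastPow2(%d) = %d\n", n, GetLastPow2(n));
      }
      // Prints 1, 2, 2, 4, 8, 64, 512, 1024.
      return 0;
    }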
@@ -453,6 +472,29 @@ __device__ __forceinline__ void Reduce(T* out,
   }
 }

+/*
+* @brief Fill register with a constant according to OpFunc.
+*
+* @template parameters
+* InT: The data type of the input.
+* OutT: The data type of out.
+* NX: The number of data columns loaded by each thread.
+* NY: The number of data rows loaded by each thread.
+* BlockSize: Identifies the current device thread index method. Currently only
+* GPU was supported.
+* OpFunc: Compute functor which has an operator() as following
+* template <typename InT>
+* struct XxxFunctor {
+*   HOSTDEVICE InT operator()() const {
+*     return a;
+*   }
+* };
+*
+* @param
+* out: The register pointer of out, the size is NX * NY.
+* compute: Compute function which was declared like OpFunc().
+*/
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; idx++) {
+    out[idx] = static_cast<OutT>(compute());
+  }
+}
+
+/*
+* @brief Fill register with random numbers according to OpFunc.
+*
+* @template parameters
+* StateType: The type of the random state passed to the functor.
+* OutT: The data type of out.
+* ReturnsCount: The number of random values each functor call returns.
+* BlockSize: Identifies the current device thread index method. Currently only
+* GPU was supported.
+* OpFunc: Compute functor which has an operator() as following
+* template <typename T>
+* struct XxxFunctor {
+*   HOSTDEVICE InT operator()(StateType state) const {
+*     return random(state);  // Returns ReturnsCount random numbers with
+*                            // data type T
+*   }
+* };
+*
+* @param
+* out: The register pointer of out, the size is NX * NY.
+* compute: Compute function which was declared like OpFunc().
+*/
+template <typename StateType, typename OutT, int ReturnsCount, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseRandom(OutT* out,
+                                                  OpFunc compute,
+                                                  StateType* state) {
+#pragma unroll
+  for (int idx = 0; idx < ReturnsCount; idx++) {
+    out[idx] = static_cast<OutT>(compute(state));
+  }
+}
+
+/*
+* @brief Compute the prefix sum (cumsum) within the block; each thread handles
+* 2 elements, the size of out and in is 2, and blockDim.x must be less than
+* 512.
+*
+* @template parameters
+* InT: the type of input register.
+* OutT: the type of out register.
+* BlockSize: Identifies the current device thread index method. Currently only
+* GPU was supported.
+* OpFunc: Compute functor which has an operator() as following
+* template <typename T>
+* struct XxxFunctor {
+*   HOSTDEVICE InT operator()(T a, T b) const {
+*     return a + b;
+*   }
+* };
+*
+* @param
+* out: The register pointer of out, the size is 2;
+* in: The register pointer of input, the size is 2;
+* compute: Compute function which was declared like OpFunc().
+*/
+
+#define SHARED_SIZE_LIMIT 512
+template <typename InT, typename OutT, int BlockSize, class OpFunc>
 __device__ __forceinline__ void Cumsum(OutT* out,
                                        const InT* in,
                                        OpFunc compute) {
-  __shared__ InT temp[shared_size * 2 + (shared_size * 2) / 32];
+  constexpr int kSize = SHARED_SIZE_LIMIT * 2 + (SHARED_SIZE_LIMIT * 2) / 32;
+  __shared__ InT temp[kSize];
+  int stride_size = blockDim.x;
   int tidx = threadIdx.x;
   temp[tidx + tidx / 32] = in[0];
-  temp[shared_size + tidx + (shared_size + tidx) / 32] = in[1];
-  for (int stride = 1; stride <= blockDim.x; stride *= 2) {
+  temp[stride_size + tidx + (stride_size + tidx) / 32] = in[1];
+  for (int stride = 1; stride <= stride_size; stride *= 2) {
     __syncthreads();
     int index = (tidx + 1) * 2 * stride - 1;
     if (index < (blockDim.x * 2)) {
-      temp[index + index / 32] += temp[index - stride + (index - stride) / 32];
+      temp[index + index / 32] =
+          compute(temp[index + index / 32],
+                  temp[index - stride + (index - stride) / 32]);
     }
   }
   for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
     __syncthreads();
     int index = (tidx + 1) * 2 * stride - 1;
     if ((index + stride) < (blockDim.x * 2)) {
-      temp[index + stride + (stride + index) / 32] +=
-          temp[index + (index) / 32];
+      temp[index + stride + (stride + index) / 32] =
+          compute(temp[index + stride + (stride + index) / 32],
+                  temp[index + (index) / 32]);
     }
   }
   __syncthreads();
   out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
   out[1] =
-      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
+      static_cast<OutT>(temp[tidx + stride_size + (tidx + stride_size) / 32]);
 }
-
-#define SHARED_SIZE_LIMIT \
-  1024  // each thread load 2 data from global memory so SHARED_SIZE_LIMIT must
-        // larger than blockDim.x * 2
-// if monotonic_type = 1 then increase
-// if gridDim.x > 1 please set monotonic_type = blockIdx.x & 1; blockIdx.x % 2
-// == 1 the increase
-template <typename T>
-__device__ __forceinline__ void Sort(T* dst,
-                                     const T* src_data,
+#undef SHARED_SIZE_LIMIT
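The reworked Cumsum is a generic inclusive scan: the old hard-coded += is replaced by compute(...), so any associative OpFunc should work. A serial host-side model of the result it appears to produce (each thread's out[0]/out[1] correspond to positions tidx and tidx + blockDim.x of this scan); this is a sketch of the semantics, not the kernel:

    #include <cstdio>
    #include <vector>

    // Any associative binary functor, mirroring OpFunc in kps::Cumsum.
    struct AddFunctor {
      float operator()(float a, float b) const { return a + b; }
    };

    // Serial model of an inclusive scan over the 2 * blockDim.x block elements.
    template <class OpFunc>
    std::vector<float> InclusiveScan(const std::vector<float>& in, OpFunc compute) {
      std::vector<float> out(in.size());
      float acc = 0.f;  // overwritten at i == 0, so no identity is needed
      for (size_t i = 0; i < in.size(); ++i) {
        acc = (i == 0) ? in[0] : compute(acc, in[i]);
        out[i] = acc;
      }
      return out;
    }

    int main() {
      std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8};
      for (float v : InclusiveScan(in, AddFunctor())) std::printf("%.0f ", v);
      std::printf("\n");  // 1 3 6 10 15 21 28 36
      return 0;
    }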
+
+/*
+* @brief Sort data in this block; each thread handles 2 elements, the size of
+* out and in is 2, and blockDim.x must be less than 512.
+*
+* @template parameters
+* InT: the type of input register.
+* OutT: the type of out register.
+* BlockSize: Identifies the current device thread index method. Currently only
+* GPU was supported.
+*
+* @param
+* out: The register pointer of out, the size is 2.
+* in: The register pointer of input, the size is 2.
+* num: The number of elements this block sorts.
+* monotonic_type: if monotonic_type = 1 the result is sorted in ascending
+* order, else in descending order.
+*/
+#define SHARED_SIZE_LIMIT 1024
+// each thread loads 2 data from global memory, so SHARED_SIZE_LIMIT must be
+// larger than blockDim.x * 2
+template <typename InT, typename OutT, int BlockSize>
+__device__ __forceinline__ void Sort(OutT* out,
+                                     const InT* in,
                                      int num,
                                      int monotonic_type) {
-  // todo: set num = Pow2(num)
+  int upper_bound = blockDim.x;
+  // update upper_bound
+  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
   // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
-  __shared__ T value[SHARED_SIZE_LIMIT];  // shareMem's size must larger than
-                                          // blockDim * 2
-  // Copy value and index from src and src_index
-  value[threadIdx.x] = src_data[0];
-  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
+  __shared__ InT value[SHARED_SIZE_LIMIT];
+  int stride_size = blockDim.x;
+  // shareMem's size must larger than blockDim * 2
+  // Copy value from in
+  value[threadIdx.x] = in[0];
+  value[threadIdx.x + stride_size] = in[1];
   // make bitonicSort
-  for (int size = 2; size < num; size <<= 1) {
+  for (int size = 2; size < upper_bound; size <<= 1) {
     int bitonic_type = (threadIdx.x & (size / 2)) != 0;
     for (int stride = size / 2; stride > 0; stride >>= 1) {
       __syncthreads();
       int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
-      details::Comparator(&value[pos], &value[pos + stride], bitonic_type);
+      details::Comparator<InT>(&value[pos], &value[pos + stride], bitonic_type);
     }
   }
   // last sort
-  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
+  for (int stride = stride_size; stride > 0; stride >>= 1) {
     __syncthreads();
     int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
     // last sort when monotonic_type = 1 then increase
-    details::Comparator(&value[pos], &value[pos + stride], monotonic_type);
+    details::Comparator<InT>(&value[pos], &value[pos + stride], monotonic_type);
   }
   __syncthreads();
-  dst[0] = value[threadIdx.x];
-  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
+  out[0] = static_cast<OutT>(value[threadIdx.x]);
+  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
 }
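The monotonic_type flag picks the final merge direction: 1 sorts ascending, anything else descending. The removed comment above notes that with gridDim.x > 1 the caller is expected to alternate it per block (monotonic_type = blockIdx.x & 1) to keep adjacent block outputs bitonic. A trivial host-side model of just the flag semantics (not the bitonic network itself):

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Host model: the output order a block produces for each monotonic_type.
    void SortBlock(std::vector<int>& v, int monotonic_type) {
      if (monotonic_type == 1) {
        std::sort(v.begin(), v.end());                       // ascending
      } else {
        std::sort(v.begin(), v.end(), std::greater<int>());  // descending
      }
    }

    int main() {
      std::vector<int> a = {3, 1, 4, 1, 5}, b = a;
      SortBlock(a, 1);
      SortBlock(b, 0);
      for (int x : a) std::printf("%d ", x);  // 1 1 3 4 5
      std::printf("\n");
      for (int x : b) std::printf("%d ", x);  // 5 4 3 1 1
      std::printf("\n");
      return 0;
    }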
-template <typename T, typename IndexType>
-__device__ __forceinline__ void Sort(T* dst,
-                                     IndexType* dst_index,
-                                     const T* src_data,
-                                     IndexType* src_index,
+/*
+* @brief Sort data with data_index in this block; each thread handles 2
+* elements, the size of out and in is 2, and blockDim.x must be less than 512.
+*
+* @template parameters
+* InT: The type of input register.
+* OutT: The type of out register.
+* IndexType: The type of index.
+* BlockSize: Identifies the current device thread index method. Currently only
+* GPU was supported.
+*
+* @param
+* out: The register pointer of out, the size is 2.
+* out_index: The register pointer of out_index, the size is 2.
+* in: The register pointer of input, the size is 2.
+* in_index: The register pointer of in_index, the size is 2.
+* num: The number of elements this block sorts.
+* monotonic_type: if monotonic_type = 1 the result is sorted in ascending
+* order, else in descending order.
+*/
+template <typename InT, typename OutT, typename IndexType, int BlockSize>
+__device__ __forceinline__ void Sort(OutT* out,
+                                     IndexType* out_index,
+                                     const InT* in,
+                                     IndexType* in_index,
                                      int num,
                                      int monotonic_type) {
-  // todo: set num = Pow2(num)
+  int upper_bound = blockDim.x;
+  // update upper_bound
+  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
   // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
-  __shared__ T value[SHARED_SIZE_LIMIT];  // shareMem's size must larger than
-                                          // blockDim * 2
+  __shared__ InT value[SHARED_SIZE_LIMIT];
+  // shareMem's size must larger than blockDim * 2
   __shared__ IndexType index[SHARED_SIZE_LIMIT];
-  // Copy value and index from src and src_index
-  value[threadIdx.x] = src_data[0];
-  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
+  // Copy value and index from in and in_index
+  int stride_size = blockDim.x;
+  value[threadIdx.x] = in[0];
+  value[threadIdx.x + stride_size] = in[1];
   // index
-  index[threadIdx.x] = src_index[0];
-  index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1];
+  index[threadIdx.x] = in_index[0];
+  index[threadIdx.x + stride_size] = in_index[1];
   // make bitonicSort
-  for (int size = 2; size < num; size <<= 1) {
+  for (int size = 2; size < upper_bound; size <<= 1) {
     int bitonic_type = (threadIdx.x & (size / 2)) != 0;
     for (int stride = size / 2; stride > 0; stride >>= 1) {
       __syncthreads();
       int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
-      details::ComparatorWithIndex(&value[pos],
-                                   &value[pos + stride],
-                                   &index[pos],
-                                   &index[pos + stride],
-                                   bitonic_type);
+      details::ComparatorWithIndex<InT, IndexType>(&value[pos],
+                                                   &value[pos + stride],
+                                                   &index[pos],
+                                                   &index[pos + stride],
+                                                   bitonic_type);
     }
   }
-  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
+  for (int stride = stride_size; stride > 0; stride >>= 1) {
     __syncthreads();
     int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
     // last sort when monotonic_type = 1 then increase
-    details::ComparatorWithIndex(&value[pos],
-                                 &value[pos + stride],
-                                 &index[pos],
-                                 &index[pos + stride],
-                                 monotonic_type);
+    details::ComparatorWithIndex<InT, IndexType>(&value[pos],
+                                                 &value[pos + stride],
+                                                 &index[pos],
+                                                 &index[pos + stride],
+                                                 monotonic_type);
   }
   __syncthreads();
-  dst[0] = value[threadIdx.x];
-  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
-  dst_index[0] = index[threadIdx.x];
-  dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
+  out[0] = static_cast<OutT>(value[threadIdx.x]);
+  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
+  out_index[0] = index[threadIdx.x];
+  out_index[1] = index[threadIdx.x + stride_size];
+}
+
+template <typename T1, typename T2, typename OutT, typename OpFunc>
+HOSTDEVICE __forceinline__ void OperatorTernary(
+    OutT* out, const T1* in1, const T2* in2, OpFunc func, int num) {
+  func(out, in1, in2, num);
+}
+
+template <typename InT, typename OutT, typename OpFunc>
+HOSTDEVICE __forceinline__ void OperatorBinary(OutT* out,
+                                               const InT* in,
+                                               OpFunc func,
+                                               int num) {
+  func(out, in, num);
 }

 }  // namespace kps
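OperatorTernary and OperatorBinary only forward a raw span to a user functor; the functor owns the loop. A host-side sketch of the functor shape they expect (the functor name is illustrative, not from the patch):

    #include <cstdio>

    // Host replica of kps::OperatorBinary's forwarding behavior.
    template <typename InT, typename OutT, typename OpFunc>
    void OperatorBinary(OutT* out, const InT* in, OpFunc func, int num) {
      func(out, in, num);
    }

    // Example functor: the callee receives the whole span and loops itself.
    struct NegateFunctor {
      void operator()(float* out, const int* in, int num) const {
        for (int i = 0; i < num; ++i) out[i] = -static_cast<float>(in[i]);
      }
    };

    int main() {
      int in[4] = {1, 2, 3, 4};
      float out[4];
      OperatorBinary(out, in, NegateFunctor(), 4);
      for (float v : out) std::printf("%.0f ", v);  // -1 -2 -3 -4
      std::printf("\n");
      return 0;
    }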
diff --git a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
index a445f4a02ea714b2b2851d4de178b5ba76f5678d..1f4ef2ed932e9f986e0c59d9b4da891817cf7afe 100644
--- a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
@@ -348,6 +348,29 @@ __device__ __forceinline__ void Reduce(T* out,
   }
 }

+/*
+* @brief Fill register with a constant according to OpFunc.
+*
+* @template parameters
+* InT: The data type of the input.
+* OutT: The data type of out.
+* NX: The number of data columns loaded by each thread.
+* NY: The number of data rows loaded by each thread.
+* BlockSize: Identifies the current device thread index method. For xpu,
+* core_id() is used as the index.
+* OpFunc: Compute functor which has an operator() as following
+* template <typename InT>
+* struct XxxFunctor {
+*   HOSTDEVICE InT operator()() const {
+*     return a;
+*   }
+* };
+*
+* @param
+* out: The register pointer of out, the size is NX * NY.
+* compute: Compute function which was declared like OpFunc().
+*/
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
+          class OpFunc>
+__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
+#pragma unroll
+  for (int idx = 0; idx < NX * NY; idx++) {
+    out[idx] = static_cast<OutT>(compute());
+  }
+}
diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h
--- a/paddle/phi/kernels/primitive/datamover_primitives.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives.h
 /**
  * @brief Read 1D data from global memory to register. The difference
  * from the above function is that it supports different data types of inputs.
+ *
+ * @template parameters
+ * T: The type of data.
+ * NX: Each thread loads NX data from global memory continuously.
+ * NY: Each thread needs to load NY rows; only NY = 1 was supported.
+ * ArgsT: The type of dst; ArgsT can be std::tuple<Args> or std::tuple<Args, T>.
+ * Index: The index of data stored in dst.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
+ * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
+ * When the number of data processed by this block is less than
+ * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
+ * crossing the boundary.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX * NY.
+ * src: The data pointer of the current block.
+ * size: The current block needs to load size data continuously.
  */
 template <typename T, int NX, int NY, int BlockSize>
 __device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
   int thread_offset = block_offset + threadIdx.x * NX;
diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
index 75b2dbaf7e6a305fdb32ae3738944922fb4a93a5..53a8b7d0c9ef9489056ab293d97e5767b23531fe 100644
--- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
@@ -244,6 +244,24 @@ __device__ __inline__ void ReadData(T* dst,
 /**
  * @brief Read 1D data from global memory to register. The difference
  * from the above function is that it supports different data types of inputs.
+ *
+ * @template parameters
+ * T: The type of data.
+ * NX: Each thread loads NX data from global memory continuously.
+ * NY: Each thread needs to load NY rows; only NY = 1 was supported.
+ * ArgsT: The type of dst; ArgsT can be std::tuple<Args> or std::tuple<Args, T>.
+ * Index: The index of data stored in dst.
+ * BlockSize: Identifies the current device thread index method. For xpu,
+ * core_id() is used as the index.
+ * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
+ * When the number of data processed by this block is less than
+ * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
+ * crossing the boundary.
+ *
+ * @param:
+ * dst: The register pointer of the thread, the size is NX * NY.
+ * src: The data pointer of the current block.
+ * size: The current block needs to load size data continuously.
  */
+template <typename T, int NX, int NY, int BlockSize>
+__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
+  int thread_offset = block_offset + core_id() * NX;
+#pragma unroll
+  for (int nx = 0; nx < NX; ++nx) {
+    dst[nx] = static_cast<T>(thread_offset + nx);
+  }
+}
+
 }  // namespace kps
 }  // namespace phi
diff --git a/paddle/phi/kernels/selu_grad_kernel.h b/paddle/phi/kernels/selu_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..42cde6deabe1c27c42397ec97221ff6790c8ed7a
--- /dev/null
+++ b/paddle/phi/kernels/selu_grad_kernel.h
@@ -0,0 +1,29 @@
+
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SeluGradKernel(const Context& dev_ctx,
+                    const DenseTensor& out,
+                    const DenseTensor& d_out,
+                    float scale,
+                    float alpha,
+                    DenseTensor* d_x);
+}  // namespace phi
diff --git a/paddle/phi/kernels/selu_kernel.h b/paddle/phi/kernels/selu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..cd5d27e98ccc109a07c3ea170d1d5953757b2152
--- /dev/null
+++ b/paddle/phi/kernels/selu_kernel.h
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SeluKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                float scale,
+                float alpha,
+                DenseTensor* out);
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc
index 74704671f8b5d244b2c3b07ada5e592a8c64da27..6395486ed2b724e522ca60cd86e104516325f1bd 100644
--- a/paddle/phi/ops/compat/reduce_sig.cc
+++ b/paddle/phi/ops/compat/reduce_sig.cc
@@ -17,28 +17,36 @@ limitations under the License. */
 namespace phi {

 KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
   if (ctx.IsDenseTensorInput("X")) {
-    if (!reduce_all) {
-      return KernelSignature(
-          "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
+    bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
+    // When ctx is InferShapeArgumentMappingContext, reduce_all is used in
+    // InferShape, so we must return the "sum_raw" KernelSignature.
+    // And the InferMeta function (i.e. ReduceInferMetaBase) is in accordance
+    // with the "sum_raw" KernelSignature.
+    if (ctx.IsForInferShape() || reduce_all) {
+      return KernelSignature("sum_raw",
+                             {"X"},
+                             {"dim", "keep_dim", "reduce_all", "out_dtype"},
+                             {"Out"});
     }
-    return KernelSignature("sum_raw",
-                           {"X"},
-                           {"dim", "keep_dim", "reduce_all", "out_dtype"},
-                           {"Out"});
+    return KernelSignature(
+        "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
   }
   return KernelSignature("unregistered", {}, {}, {});
 }

 KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
   if (ctx.IsDenseTensorInput("X")) {
-    if (!reduce_all) {
-      return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
+    bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
+    // When ctx is InferShapeArgumentMappingContext, reduce_all is used in
+    // InferShape, so we must return the "mean_raw" KernelSignature.
+    // And the InferMeta function (i.e. MeanRawInferMeta) is in accordance with
+    // the "mean_raw" KernelSignature.
+    if (ctx.IsForInferShape() || reduce_all) {
+      return KernelSignature(
+          "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
     }
-    return KernelSignature(
-        "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
+    return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
   }
   return KernelSignature("unregistered", {}, {}, {});
 }
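Both mappings follow the same pattern: fall back to the attribute-complete *_raw signature whenever the context is shape inference (because reduce_all participates in InferShape) or reduce_all is actually set. A host-side model of just that branch, using a hypothetical op name rather than anything from this patch:

    #include <cstdio>
    #include <string>

    // Hypothetical model of the raw/non-raw kernel-signature choice above.
    std::string ChooseSignature(bool is_for_infer_shape, bool reduce_all) {
      if (is_for_infer_shape || reduce_all) return "my_op_raw";  // full attrs
      return "my_op";  // trimmed attrs for the common path
    }

    int main() {
      std::printf("%s\n", ChooseSignature(true, false).c_str());   // my_op_raw
      std::printf("%s\n", ChooseSignature(false, true).c_str());   // my_op_raw
      std::printf("%s\n", ChooseSignature(false, false).c_str());  // my_op
      return 0;
    }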
diff --git a/paddle/phi/ops/compat/selu_sig.cc b/paddle/phi/ops/compat/selu_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..23f5cc34515b4aba482e2cfe3d6e0d148e2d97b2
--- /dev/null
+++ b/paddle/phi/ops/compat/selu_sig.cc
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature SeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("selu_grad",
+                         {"Out", GradVarName("Out")},
+                         {"scale", "alpha"},
+                         {GradVarName("X")});
+}
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(selu_grad, phi::SeluGradOpArgumentMapping);
diff --git a/paddle/phi/tests/ops/test_op_signature.h b/paddle/phi/tests/ops/test_op_signature.h
index fcd2d397fa2db593dffe6b0c898efedc2e62cd81..06048f33d940a28ddf9e3aa488a6e24a9e4a93b6 100644
--- a/paddle/phi/tests/ops/test_op_signature.h
+++ b/paddle/phi/tests/ops/test_op_signature.h
@@ -80,6 +80,8 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
     return selected_rows_outputs.count(name) > 0;
   }

+  bool IsForInferShape() const override { return false; }
+
  private:
   const std::unordered_set<std::string> dense_tensor_inputs;
   const std::unordered_set<std::string> selected_rows_inputs;
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index 4d7451f435271b4aaca3010e643ddcb5fbb28191..8528ba34e210e4d86f32211022693db5f37f1326 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -1266,7 +1266,7 @@ function card_test() {
     elif [ "${WITH_ASCEND_CL}" == "ON" ];then
         CUDA_DEVICE_COUNT=1
     elif [ "${WITH_ROCM}" == "ON" ];then
-        CUDA_DEVICE_COUNT=4
+        CUDA_DEVICE_COUNT=$(rocm-smi -i | grep GPU | wc -l)
     else
         CUDA_DEVICE_COUNT=$(nvidia-smi -L | wc -l)
     fi
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index 494ea9697971974d20c917006225df55f531ff70..f75a0fa50a59c1dd570f3b35ff5b3c9108564e78 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -25,6 +25,12 @@ function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_pa
   _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
 endfunction()

+function(download_data install_dir url data_file check_sum)
+  if (NOT EXISTS ${install_dir}/${data_file})
+    inference_download_and_uncompress(${install_dir} ${url} ${data_file} ${check_sum})
+  endif()
+endfunction()
+
 function(download_quant_data install_dir data_file check_sum)
   if (NOT EXISTS ${install_dir}/${data_file})
     inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum})
@@ -290,8 +296,9 @@ if(LINUX AND WITH_MKLDNN)

     ### PTQ INT8
     # PTQ int8 lstm model
-    set(LSTM_DATA_ARCHIVE "unittest_model_data/quant_lstm_input_data.tar.gz")
-    download_quant_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_DATA_ARCHIVE} add84c754e9b792fea1fbd728d134ab7)
+    set(LSTM_DATA_FILE "quant_lstm_input_data.tar.gz")
+    set(LSTM_URL "${INFERENCE_URL}/int8/unittest_model_data")
+    download_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_URL} ${LSTM_DATA_FILE} add84c754e9b792fea1fbd728d134ab7)
     set(QUANT2_FP32_LSTM_MODEL_ARCHIVE "lstm_fp32_model.tar.gz")
     download_lstm_model(${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_FP32_LSTM_MODEL_ARCHIVE} eecd9f44d69a84acc1cf2235c4b8b743)
     inference_quant2_int8_lstm_model_test(test_quant2_int8_lstm_mkldnn ${QUANT2_INT8_LSTM_SAVE_PATH}/lstm_fp32_model ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH}/quant_lstm_input_data)
diff --git a/python/paddle/fluid/tests/unittests/ipu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ipu/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..959700ad743b40420200b56055354279386a9a7c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ipu/CMakeLists.txt
@@ -0,0 +1,8 @@
+if(WITH_IPU)
+  file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
+  string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
+
+  foreach(TEST_OP ${TEST_OPS})
+    py_test_modules(${TEST_OP} MODULES ${TEST_OP})
+  endforeach(TEST_OP)
+endif()
diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py
index 0fc56726c5d0e57f88ec38db95c76495931d0f26..13b880b28bf851affc815c7f9df8779370619178 100644
--- a/python/paddle/fluid/tests/unittests/test_transpose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py
@@ -423,6 +423,14 @@ class TestMoveAxis(unittest.TestCase):
         self.assertEqual(np.array_equal(out.numpy(), expected), True)
         paddle.enable_static()

+    def test_moveaxis3(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(
+            [[1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j]])
+        out = x.moveaxis(0, 1)
+        self.assertEqual(out.shape, [2, 3])
+        paddle.enable_static()
+
     def test_error(self):
         x = paddle.randn([2, 3, 4, 5])
         # src must have the same number with dst
diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py
index 541df6659c2d5620e449d2aee0707987ee43d042..dbd40c349bbc81d39b8a929ee5b3e7b81a083406 100644
--- a/python/paddle/fluid/tests/unittests/test_var_base.py
+++ b/python/paddle/fluid/tests/unittests/test_var_base.py
@@ -51,6 +51,10 @@ class TestVarBase(unittest.TestCase):
                     np.array_equal(x.numpy(), np.array([1.2], 'float16')))
                 self.assertEqual(x.dtype, core.VarDesc.VarType.FP16)

+                # set_default_dtype takes effect on int
+                x = paddle.to_tensor(1, place=place)
+                self.assertEqual(x.dtype, core.VarDesc.VarType.INT64)
+
                 # set_default_dtype takes effect on float
                 x = paddle.to_tensor(1.2, place=place, stop_gradient=False)
                 self.assertTrue(
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index ae563e641e3c854e6d516ada20beb2dafb151578..bddc45bc9612c34c4fab1b6c4ce6ae1fc47a1052 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -110,12 +110,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
             "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
         )

-    #Todo(zhouwei): Support allocate tensor on any other specified card
-    if isinstance(place, core.CUDAPlace) and isinstance(
-            _current_expected_place(), core.CUDAPlace) and place._get_device_id(
-    ) != _current_expected_place()._get_device_id():
-        place = _current_expected_place()
-
     if not isinstance(data, np.ndarray):

         def _handle_dtype(data, dtype):
@@ -139,7 +133,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
             data.stop_gradient = stop_gradient
             return data
         elif isinstance(data, (core.LoDTensor, core.Tensor)):
-            # Note(zhouwei25): should't expose it to users, just for internal use.
+            # shouldn't expose it to users, just for internal use.
             # convert core.Tensor/core.LoDTensor to VarBase first
             # Currently, there is no copy when places are same
             data = paddle.Tensor(data)
@@ -152,15 +146,20 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
             raise TypeError(
                 "Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor".
                format(type(data)))
-    if not dtype and data.dtype in [
-            'float16', 'float32', 'float64', 'complex64', 'complex128'
-    ]:
-        default_type = paddle.get_default_dtype()
-        if np.iscomplexobj(data):
-            default_type = 'complex64' if default_type in [
-                'float16', 'float32'
-            ] else 'complex128'
-        data = data.astype(default_type)
+    if not dtype:
+        if data.dtype in [
+                'float16', 'float32', 'float64', 'complex64', 'complex128'
+        ]:
+            default_type = paddle.get_default_dtype()
+            if np.iscomplexobj(data):
+                default_type = 'complex64' if default_type in [
+                    'float16', 'float32'
+                ] else 'complex128'
+            data = data.astype(default_type)
+        # Windows default type is 'int32', while Linux/Mac is 'int64'. Unify them.
+        if data.dtype in ['int32']:
+            default_type = "int64"
+            data = data.astype(default_type)

     if dtype and convert_dtype(dtype) != data.dtype:
         data = data.astype(convert_dtype(dtype))
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 53bb9a8807562866d810bcf36a0329b7cadd7ebd..fbd6197c1b92ee8481a1ce6f4a2cec8482eaefb0 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -2737,9 +2737,10 @@ def moveaxis(x, source, destination, name=None):
         out, _ = _C_ops.transpose2(x, 'axis', perm)
         return out

-    check_variable_and_dtype(
-        x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
-        'moveaxis')
+    check_variable_and_dtype(x, 'x', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
+        'complex128'
+    ], 'moveaxis')

     helper = LayerHelper('moveaxis', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 7ea8493b67fd6dec6f46df8ca854bbd700ffbfa6..45a6aae5e6dddd89509c120ffbb67540669f796a 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -124,7 +124,7 @@
   args : (Tensor x, int64_t[] axis={}, bool keep_dim=false)
   output : Tensor
   infer_meta :
-    func : ReduceInferMeta
+    func : MeanInferMeta
   kernel :
     func : mean
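Routing mean through MeanInferMeta keeps api.yaml's argument list free of reduce_all while the underlying shape rule stays shared with the raw path. A host-side model (a sketch of the standard reduce shape rule, not the phi code; the keep_dim branch is elided in the diff above and is an assumption here):

    #include <cstdint>
    #include <cstdio>
    #include <set>
    #include <vector>

    // Sketch: normalize axes, detect a full reduction, then keep or drop the
    // reduced dimensions depending on keep_dim.
    std::vector<int64_t> ReduceOutDims(const std::vector<int64_t>& x_dims,
                                       std::vector<int64_t> axis,
                                       bool keep_dim,
                                       bool reduce_all) {
      int64_t rank = static_cast<int64_t>(x_dims.size());
      for (auto& a : axis)
        if (a < 0) a += rank;
      std::set<int64_t> dims(axis.begin(), axis.end());
      bool full_dim = true;
      for (int64_t i = 0; i < rank; ++i) {
        if (dims.count(i) == 0) {
          full_dim = false;
          break;
        }
      }
      reduce_all = reduce_all || full_dim;

      std::vector<int64_t> out;
      for (int64_t i = 0; i < rank; ++i) {
        bool reduced = reduce_all || dims.count(i) > 0;
        if (!reduced) out.push_back(x_dims[i]);
        else if (keep_dim) out.push_back(1);
      }
      return out;
    }

    int main() {
      for (int64_t d : ReduceOutDims({2, 3, 4}, {1}, false, false))
        std::printf("%lld ", static_cast<long long>(d));  // 2 4
      std::printf("\n");
      for (int64_t d : ReduceOutDims({2, 3, 4}, {0, 1, 2}, true, false))
        std::printf("%lld ", static_cast<long long>(d));  // 1 1 1
      std::printf("\n");
      return 0;
    }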