diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 93bc2c02d57cb7b57cf48d6f5c34a27a97637377..14997dd9610138e32a45ef17abc9276cd1dad172 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -125,6 +125,15 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } +#endif +#ifdef PADDLE_WITH_IPU + if (platform::is_ipu_place(expected_kernel_key.place_)) { + VLOG(3) << "pten missing IPU kernel: " << op.Type() + << ", expected_kernel_key:" << expected_kernel_key + << ", fallbacking to CPU one!"; + return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), + kernel_key.dtype()); + } #endif return phi::KernelKey(); } diff --git a/paddle/fluid/platform/device/ipu/ipu_strategy.cc b/paddle/fluid/platform/device/ipu/ipu_strategy.cc index 943dfcc6cffb875fc3cebfc88e35adeaba47fd63..e806b0b30e4e03759847cc2e1838171020a064b1 100644 --- a/paddle/fluid/platform/device/ipu/ipu_strategy.cc +++ b/paddle/fluid/platform/device/ipu/ipu_strategy.cc @@ -120,121 +120,151 @@ IpuStrategy::IpuStrategy() { RegisterGetter(options_getter, options_type, #name, "string", \ [&]() { return popart_options.aliased_name; }) -#define ADD_POPART_ENUM_OPTION(name, EnumType) \ - ADD_POPART_ENUM_OPTION_ALIAS(name, name, EnumType) - -#define ADD_POPART_BOOL_OPTION(name) ADD_POPART_BOOL_OPTION_ALIAS(name, name) - -#define ADD_POPART_UINT64_OPTION(name) \ - ADD_POPART_UINT64_OPTION_ALIAS(name, name) - -#define ADD_POPART_DOUBLE_OPTION(name) \ - ADD_POPART_DOUBLE_OPTION_ALIAS(name, name) - -#define ADD_POPART_STRING_OPTION(name) \ - ADD_POPART_STRING_OPTION_ALIAS(name, name) - - ADD_POPART_ENUM_OPTION(autodiffSettings.stitchStrategy, - AutodiffStitchStrategy); - ADD_POPART_ENUM_OPTION(batchSerializationSettings.transformContext, - BatchSerializationTransformContext); - ADD_POPART_ENUM_OPTION(batchSerializationSettings.method, - BatchSerializationMethod); - ADD_POPART_ENUM_OPTION(batchSerializationSettings.batchSchedule, - BatchSerializationBatchSchedule); - ADD_POPART_ENUM_OPTION(autoRecomputation, RecomputationType); - ADD_POPART_ENUM_OPTION(mergeVarUpdate, MergeVarUpdateType); - ADD_POPART_ENUM_OPTION(virtualGraphMode, VirtualGraphMode); - ADD_POPART_ENUM_OPTION(syntheticDataMode, SyntheticDataMode); - ADD_POPART_ENUM_OPTION(subgraphCopyingStrategy, SubgraphCopyingStrategy); - ADD_POPART_ENUM_OPTION(accumulationAndReplicationReductionType, - ReductionType); - ADD_POPART_ENUM_OPTION(meanAccumulationAndReplicationReductionStrategy, - MeanReductionStrategy); - - ADD_POPART_STRING_OPTION(logDir); - ADD_POPART_STRING_OPTION(cachePath); - ADD_POPART_STRING_OPTION(partialsTypeMatMuls); - ADD_POPART_STRING_OPTION(customCodeletCompileFlags); - ADD_POPART_STRING_OPTION(serializedPoprithmsShiftGraphsDir); - ADD_POPART_STRING_OPTION(kahnTieBreaker); - - ADD_POPART_UINT64_OPTION(executionPhaseSettings.phases); - ADD_POPART_UINT64_OPTION(executionPhaseSettings.stages); - ADD_POPART_UINT64_OPTION(batchSerializationSettings.factor); - ADD_POPART_UINT64_OPTION(firstDotOp); - ADD_POPART_UINT64_OPTION(finalDotOp); - ADD_POPART_UINT64_OPTION(numIOTiles); - ADD_POPART_UINT64_OPTION(mergeVarUpdateMemThreshold); - ADD_POPART_UINT64_OPTION(looseThresholdAtPeak); - ADD_POPART_UINT64_OPTION(accumulationFactor); - ADD_POPART_UINT64_OPTION(swapLimitScheduler); - ADD_POPART_UINT64_OPTION(globalReplicationFactor); - ADD_POPART_UINT64_OPTION(globalReplicaOffset); - ADD_POPART_UINT64_OPTION(defaultPrefetchBufferingDepth); - ADD_POPART_UINT64_OPTION(compilationProgressTotal); - ADD_POPART_UINT64_OPTION(transitiveClosureOptimizationThreshold); - - ADD_POPART_BOOL_OPTION(batchSerializationSettings.concatOnVirtualGraphChange); - ADD_POPART_BOOL_OPTION( + ADD_POPART_ENUM_OPTION_ALIAS(autodiff_settings.stitch_strategy, + autodiffSettings.stitchStrategy, + AutodiffStitchStrategy); + ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.transform_context, + batchSerializationSettings.transformContext, + BatchSerializationTransformContext); + ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.method, + batchSerializationSettings.method, + BatchSerializationMethod); + ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.batch_schedule, + batchSerializationSettings.batchSchedule, + BatchSerializationBatchSchedule); + ADD_POPART_ENUM_OPTION_ALIAS(auto_recomputation, autoRecomputation, + RecomputationType); + ADD_POPART_ENUM_OPTION_ALIAS(merge_var_update, mergeVarUpdate, + MergeVarUpdateType); + ADD_POPART_ENUM_OPTION_ALIAS(virtual_graph_mode, virtualGraphMode, + VirtualGraphMode); + ADD_POPART_ENUM_OPTION_ALIAS(synthetic_data_mode, syntheticDataMode, + SyntheticDataMode); + ADD_POPART_ENUM_OPTION_ALIAS(subgraph_copying_strategy, + subgraphCopyingStrategy, + SubgraphCopyingStrategy); + ADD_POPART_ENUM_OPTION_ALIAS(accumulation_and_replication_reduction_type, + accumulationAndReplicationReductionType, + ReductionType); + ADD_POPART_ENUM_OPTION_ALIAS( + mean_accumulation_and_replication_reduction_strategy, + meanAccumulationAndReplicationReductionStrategy, MeanReductionStrategy); + + ADD_POPART_STRING_OPTION_ALIAS(log_dir, logDir); + ADD_POPART_STRING_OPTION_ALIAS(cache_path, cachePath); + ADD_POPART_STRING_OPTION_ALIAS(partials_type_matmuls, partialsTypeMatMuls); + ADD_POPART_STRING_OPTION_ALIAS(custom_codelet_compile_flags, + customCodeletCompileFlags); + ADD_POPART_STRING_OPTION_ALIAS(serialized_poprithms_shift_graphs_dir, + serializedPoprithmsShiftGraphsDir); + ADD_POPART_STRING_OPTION_ALIAS(kahn_tie_breaker, kahnTieBreaker); + + ADD_POPART_UINT64_OPTION_ALIAS(execution_phase_settings.phases, + executionPhaseSettings.phases); + ADD_POPART_UINT64_OPTION_ALIAS(execution_phase_settings.stages, + executionPhaseSettings.stages); + ADD_POPART_UINT64_OPTION_ALIAS(batch_serialization_settings.factor, + batchSerializationSettings.factor); + ADD_POPART_UINT64_OPTION_ALIAS(first_dot_op, firstDotOp); + ADD_POPART_UINT64_OPTION_ALIAS(final_dot_op, finalDotOp); + ADD_POPART_UINT64_OPTION_ALIAS(num_io_tiles, numIOTiles); + ADD_POPART_UINT64_OPTION_ALIAS(merge_var_update_mem_threshold, + mergeVarUpdateMemThreshold); + ADD_POPART_UINT64_OPTION_ALIAS(loose_threshold_at_peak, looseThresholdAtPeak); + ADD_POPART_UINT64_OPTION_ALIAS(accumulation_factor, accumulationFactor); + ADD_POPART_UINT64_OPTION_ALIAS(swap_limit_scheduler, swapLimitScheduler); + ADD_POPART_UINT64_OPTION_ALIAS(global_replication_factor, + globalReplicationFactor); + ADD_POPART_UINT64_OPTION_ALIAS(global_replica_offset, globalReplicaOffset); + ADD_POPART_UINT64_OPTION_ALIAS(default_prefetch_buffering_depth, + defaultPrefetchBufferingDepth); + ADD_POPART_UINT64_OPTION_ALIAS(compilation_progress_total, + compilationProgressTotal); + ADD_POPART_UINT64_OPTION_ALIAS(transitive_closure_optimization_threshold, + transitiveClosureOptimizationThreshold); + + ADD_POPART_BOOL_OPTION_ALIAS( + batch_serialization_settings.concat_on_virtual_graph_change, + batchSerializationSettings.concatOnVirtualGraphChange); + ADD_POPART_BOOL_OPTION_ALIAS( + batch_serialization_settings.concat_on_execution_phase_change, batchSerializationSettings.concatOnExecutionPhaseChange); - ADD_POPART_BOOL_OPTION( + ADD_POPART_BOOL_OPTION_ALIAS( + batch_serialization_settings.concat_on_pipeline_stage_change, batchSerializationSettings.concatOnPipelineStageChange); - ADD_POPART_BOOL_OPTION(strictOpVersions); - ADD_POPART_BOOL_OPTION(opxAliasChecking); - ADD_POPART_BOOL_OPTION(opxModifyChecking); - ADD_POPART_BOOL_OPTION(dotOpNames); - ADD_POPART_BOOL_OPTION(exportPoplarComputationGraph); - ADD_POPART_BOOL_OPTION(exportPoplarVertexGraph); - ADD_POPART_BOOL_OPTION(separateCallOpPdfs); - ADD_POPART_BOOL_OPTION(enableOutlining); - ADD_POPART_BOOL_OPTION(enableOutliningCopyCostPruning); - ADD_POPART_BOOL_OPTION(rearrangeAnchorsOnHost); - ADD_POPART_BOOL_OPTION(enablePrefetchDatastreams); - ADD_POPART_BOOL_OPTION(enableNonStableSoftmax); - ADD_POPART_BOOL_OPTION(enableReplicatedGraphs); - ADD_POPART_BOOL_OPTION(enableGradientAccumulation); - ADD_POPART_BOOL_OPTION(instrumentWithHardwareCycleCounter); - ADD_POPART_BOOL_OPTION(enablePipelining); + ADD_POPART_BOOL_OPTION_ALIAS(strict_op_versions, strictOpVersions); + ADD_POPART_BOOL_OPTION_ALIAS(opx_alias_checking, opxAliasChecking); + ADD_POPART_BOOL_OPTION_ALIAS(opx_modify_checking, opxModifyChecking); + ADD_POPART_BOOL_OPTION_ALIAS(dot_op_names, dotOpNames); + ADD_POPART_BOOL_OPTION_ALIAS(export_poplar_computation_graph, + exportPoplarComputationGraph); + ADD_POPART_BOOL_OPTION_ALIAS(export_poplar_vertex_graph, + exportPoplarVertexGraph); + ADD_POPART_BOOL_OPTION_ALIAS(separate_call_op_pdfs, separateCallOpPdfs); + ADD_POPART_BOOL_OPTION_ALIAS(enable_outlining, enableOutlining); + ADD_POPART_BOOL_OPTION_ALIAS(enable_outlining_copy_cost_pruning, + enableOutliningCopyCostPruning); + ADD_POPART_BOOL_OPTION_ALIAS(rearrange_anchors_on_host, + rearrangeAnchorsOnHost); + ADD_POPART_BOOL_OPTION_ALIAS(enable_prefetch_datastreams, + enablePrefetchDatastreams); + ADD_POPART_BOOL_OPTION_ALIAS(enable_non_stable_softmax, + enableNonStableSoftmax); + ADD_POPART_BOOL_OPTION_ALIAS(enable_replicated_graphs, + enableReplicatedGraphs); + ADD_POPART_BOOL_OPTION_ALIAS(enable_gradient_accumulation, + enableGradientAccumulation); + ADD_POPART_BOOL_OPTION_ALIAS(instrument_with_hardware_cycle_counter, + instrumentWithHardwareCycleCounter); ADD_POPART_BOOL_OPTION_ALIAS(enable_pipelining, enablePipelining); - ADD_POPART_BOOL_OPTION(disableGradAccumulationTensorStreams); - ADD_POPART_BOOL_OPTION(compileEngine); - ADD_POPART_BOOL_OPTION(constantWeights); - ADD_POPART_BOOL_OPTION(enableEngineCaching); - ADD_POPART_BOOL_OPTION(enableMergeExchange); - ADD_POPART_BOOL_OPTION(enableFloatingPointChecks); - ADD_POPART_BOOL_OPTION(enableStochasticRounding); + ADD_POPART_BOOL_OPTION_ALIAS(disable_grad_accumulation_tensor_streams, + disableGradAccumulationTensorStreams); + ADD_POPART_BOOL_OPTION_ALIAS(compile_engine, compileEngine); + ADD_POPART_BOOL_OPTION_ALIAS(constant_weights, constantWeights); + ADD_POPART_BOOL_OPTION_ALIAS(enable_engine_caching, enableEngineCaching); + ADD_POPART_BOOL_OPTION_ALIAS(enable_merge_exchange, enableMergeExchange); + ADD_POPART_BOOL_OPTION_ALIAS(enable_floating_point_checks, + enableFloatingPointChecks); ADD_POPART_BOOL_OPTION_ALIAS(enable_stochastic_rounding, enableStochasticRounding); - ADD_POPART_BOOL_OPTION(explicitRecomputation); - ADD_POPART_BOOL_OPTION(enableExplicitMainLoops); - ADD_POPART_BOOL_OPTION(useHostCopyOps); - ADD_POPART_BOOL_OPTION(aliasZeroCopy); - ADD_POPART_BOOL_OPTION(delayVarUpdates); - ADD_POPART_BOOL_OPTION(enableFullyConnectedPass); - ADD_POPART_BOOL_OPTION(enableSerializedMatmuls); - ADD_POPART_BOOL_OPTION(enableStableNorm); - ADD_POPART_BOOL_OPTION(decomposeGradSum); - ADD_POPART_BOOL_OPTION(enableDistributedReplicatedGraphs); - ADD_POPART_BOOL_OPTION(groupHostSync); - ADD_POPART_BOOL_OPTION(automaticLossScalingSettings.enabled); - ADD_POPART_BOOL_OPTION(instrumentWithHardwareCycleCounter); - ADD_POPART_BOOL_OPTION(enableSupportedDataTypeCasting); - ADD_POPART_BOOL_OPTION(groupNormStridedChannelGrouping); - ADD_POPART_BOOL_OPTION(scheduleNonWeightUpdateGradientConsumersEarly); - - ADD_POPART_DOUBLE_OPTION(outlineSequenceBreakCost); - ADD_POPART_DOUBLE_OPTION(outlineThreshold); - ADD_POPART_DOUBLE_OPTION(timeLimitScheduler); - ADD_POPART_DOUBLE_OPTION(automaticLossScalingSettings.binEdgeLocation); - ADD_POPART_DOUBLE_OPTION( + ADD_POPART_BOOL_OPTION_ALIAS(explicit_recomputation, explicitRecomputation); + ADD_POPART_BOOL_OPTION_ALIAS(enable_explicit_main_loops, + enableExplicitMainLoops); + ADD_POPART_BOOL_OPTION_ALIAS(use_host_copy_ops, useHostCopyOps); + ADD_POPART_BOOL_OPTION_ALIAS(alias_zero_copy, aliasZeroCopy); + ADD_POPART_BOOL_OPTION_ALIAS(delay_var_updates, delayVarUpdates); + ADD_POPART_BOOL_OPTION_ALIAS(enable_fully_connected_pass, + enableFullyConnectedPass); + ADD_POPART_BOOL_OPTION_ALIAS(enable_serialized_matmuls, + enableSerializedMatmuls); + ADD_POPART_BOOL_OPTION_ALIAS(enable_stable_norm, enableStableNorm); + ADD_POPART_BOOL_OPTION_ALIAS(decompose_grad_sum, decomposeGradSum); + ADD_POPART_BOOL_OPTION_ALIAS(enable_distributed_replicated_graphs, + enableDistributedReplicatedGraphs); + ADD_POPART_BOOL_OPTION_ALIAS(group_host_sync, groupHostSync); + ADD_POPART_BOOL_OPTION_ALIAS(automatic_loss_scaling_settings.enabled, + automaticLossScalingSettings.enabled); + ADD_POPART_BOOL_OPTION_ALIAS(instrument_with_hardware_cycle_counter, + instrumentWithHardwareCycleCounter); + ADD_POPART_BOOL_OPTION_ALIAS(enable_supported_data_type_casting, + enableSupportedDataTypeCasting); + ADD_POPART_BOOL_OPTION_ALIAS(group_norm_strided_channel_grouping, + groupNormStridedChannelGrouping); + ADD_POPART_BOOL_OPTION_ALIAS( + schedule_non_weight_update_gradient_consumers_early, + scheduleNonWeightUpdateGradientConsumersEarly); + + ADD_POPART_DOUBLE_OPTION_ALIAS(outline_sequence_break_cost, + outlineSequenceBreakCost); + ADD_POPART_DOUBLE_OPTION_ALIAS(outline_threshold, outlineThreshold); + ADD_POPART_DOUBLE_OPTION_ALIAS(time_limit_scheduler, timeLimitScheduler); + ADD_POPART_DOUBLE_OPTION_ALIAS( + automatic_loss_scaling_settings.bin_edge_location, + automaticLossScalingSettings.binEdgeLocation); + ADD_POPART_DOUBLE_OPTION_ALIAS( + automatic_loss_scaling_settings.threshold_upper_count_proportion, automaticLossScalingSettings.thresholdUpperCountProportion); -#undef ADD_POPART_STRING_OPTION -#undef ADD_POPART_DOUBLE_OPTION -#undef ADD_POPART_UINT64_OPTION -#undef ADD_POPART_BOOL_OPTION -#undef ADD_POPART_ENUM_OPTION #undef ADD_POPART_STRING_OPTION_ALIAS #undef ADD_POPART_DOUBLE_OPTION_ALIAS #undef ADD_POPART_UINT64_OPTION_ALIAS @@ -278,14 +308,14 @@ IpuStrategy::IpuStrategy() { }); RegisterSetter( - container_options, "dotChecks", + container_options, "dot_checks", [&](const std::pair& p) { std::uint64_t value = std::stoul(p.first); popart_options.dotChecks.insert(static_cast(value)); }); RegisterGetter( - vector_options_getter, options_type, "dotChecks", "vector", [&]() { + vector_options_getter, options_type, "dot_checks", "vector", [&]() { std::vector res; for (auto x : popart_options.dotChecks) { res.push_back(std::to_string(static_cast(x))); @@ -293,7 +323,7 @@ IpuStrategy::IpuStrategy() { return res; }); - RegisterSetter(container_options, "hardwareInstrumentations", + RegisterSetter(container_options, "hardware_instrumentations", [&](const std::pair& p) { std::uint64_t value = std::stoul(p.first); popart_options.hardwareInstrumentations.insert( @@ -301,8 +331,8 @@ IpuStrategy::IpuStrategy() { }); RegisterGetter( - vector_options_getter, options_type, "hardwareInstrumentations", "vector", - [&]() { + vector_options_getter, options_type, "hardware_instrumentations", + "vector", [&]() { std::vector res; for (auto x : popart_options.hardwareInstrumentations) { res.push_back(std::to_string(static_cast(x))); @@ -310,12 +340,12 @@ IpuStrategy::IpuStrategy() { return res; }); - RegisterSetter(container_options, "customCodelets", + RegisterSetter(container_options, "custom_codelets", [&](const std::pair& p) { popart_options.customCodelets.push_back(p.first); }); - RegisterGetter(vector_options_getter, options_type, "customCodelets", + RegisterGetter(vector_options_getter, options_type, "custom_codelets", "vector", [&]() { std::vector res; for (auto x : popart_options.customCodelets) { @@ -324,44 +354,44 @@ IpuStrategy::IpuStrategy() { return res; }); - RegisterSetter(container_options, "engineOptions", + RegisterSetter(container_options, "engine_options", [&](const std::pair& p) { popart_options.engineOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "engineOptions", "map", + RegisterGetter(map_options_getter, options_type, "engine_options", "map", [&]() { return popart_options.engineOptions; }); - RegisterSetter(container_options, "reportOptions", + RegisterSetter(container_options, "report_options", [&](const std::pair& p) { popart_options.reportOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "reportOptions", "map", + RegisterGetter(map_options_getter, options_type, "report_options", "map", [&]() { return popart_options.reportOptions; }); - RegisterSetter(container_options, "convolutionOptions", + RegisterSetter(container_options, "convolution_options", [&](const std::pair& p) { popart_options.convolutionOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "convolutionOptions", "map", + RegisterGetter(map_options_getter, options_type, "convolution_options", "map", [&]() { return popart_options.convolutionOptions; }); - RegisterSetter(container_options, "lstmOptions", + RegisterSetter(container_options, "lstm_options", [&](const std::pair& p) { popart_options.lstmOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "lstmOptions", "map", + RegisterGetter(map_options_getter, options_type, "lstm_options", "map", [&]() { return popart_options.lstmOptions; }); - RegisterSetter(container_options, "gclOptions", + RegisterSetter(container_options, "gcl_options", [&](const std::pair& p) { popart_options.gclOptions.emplace(p); }); - RegisterGetter(map_options_getter, options_type, "gclOptions", "map", + RegisterGetter(map_options_getter, options_type, "gcl_options", "map", [&]() { return popart_options.gclOptions; }); } @@ -415,21 +445,21 @@ void IpuStrategy::SetTensorLocation(const std::string& tensor, "Unknown tensor location: %s", tensor)); } - if (opt == "minElementsForOffChip") { + if (opt == "min_elements_for_off_chip") { settings->minElementsForOffChip = value; - } else if (opt == "minElementsForReplicatedTensorSharding") { + } else if (opt == "min_elements_for_replicated_tensor_sharding") { settings->minElementsForReplicatedTensorSharding = value; - } else if (opt == "onChip") { + } else if (opt == "on_chip") { settings->location.storage = value > 0 ? popart::TensorStorage::OnChip : popart::TensorStorage::OffChip; - } else if (opt == "useReplicatedTensorSharding") { + } else if (opt == "use_replicated_tensor_sharding") { settings->location.replicatedTensorSharding = value > 0 ? popart::ReplicatedTensorSharding::On : popart::ReplicatedTensorSharding::Off; - } else if (opt == "useIOTilesToLoad") { + } else if (opt == "use_io_tiles_to_load") { settings->location.loadTileSet = value > 0 ? popart::TileSet::IO : popart::TileSet::Compute; - } else if (opt == "useIOTilesToStore") { + } else if (opt == "use_io_tiles_to_store") { settings->location.storageTileSet = value > 0 ? popart::TileSet::IO : popart::TileSet::Compute; } else { @@ -464,6 +494,20 @@ std::string IpuStrategy::GetOptionType(const std::string& option) { return options_type[option]; } +std::vector IpuStrategy::GetAllOptionNames() { + std::vector names; + for (auto& option : options_getter) { + names.push_back(option.first); + } + for (auto& option : vector_options_getter) { + names.push_back(option.first); + } + for (auto& option : map_options_getter) { + names.push_back(option.first); + } + return names; +} + void IpuStrategy::EnablePattern(const std::string& t) { VLOG(10) << "enable popart pattern: " << t; popart_patterns.enablePattern(t, true); diff --git a/paddle/fluid/platform/device/ipu/ipu_strategy.h b/paddle/fluid/platform/device/ipu/ipu_strategy.h index 64436dc14fec3393b0a2a4473ad436d7d08f5217..571fb1e163718388a779e128fb6aaf76659d7183 100644 --- a/paddle/fluid/platform/device/ipu/ipu_strategy.h +++ b/paddle/fluid/platform/device/ipu/ipu_strategy.h @@ -24,7 +24,8 @@ namespace paddle { namespace platform { namespace ipu { -struct IpuStrategy { +class IpuStrategy { + public: IpuStrategy(); // TODO(alleng) create PaddleOptions @@ -75,22 +76,30 @@ struct IpuStrategy { // custom ops std::vector custom_ops; - private: - std::map> bool_options; - std::map> uint64_options; - std::map> double_options; - std::map> string_options; - std::map)>> - container_options; + public: + void AddBoolOption(const std::string &option, bool value); + void AddUint64Option(const std::string &option, std::uint64_t value); + void AddDoubleOption(const std::string &option, double value); + void AddStringOption(const std::string &option, const std::string &value); + void InsertStringOption(const std::string &option, const std::string &value); + void InsertStringPairOption(const std::string &option, const std::string &key, + const std::string &value); + void SetTensorLocation(const std::string &tensor, const std::string &option, + std::uint64_t value); + void AddCustomOp(const std::string &paddle_op, const std::string &popart_op, + const std::string &domain, int version); - std::map> options_getter; - std::map()>> - vector_options_getter; - std::map()>> - map_options_getter; - std::map options_type; + std::string GetOption(const std::string &); + std::vector GetVectorOption(const std::string &); + std::map GetMapOption(const std::string &); + std::string GetOptionType(const std::string &); + std::vector GetAllOptionNames(); + + void EnablePattern(const std::string &t); + void DisablePattern(const std::string &t); + const bool IsPatternEnabled(const std::string &t); + private: template void set( const std::string &key, ValueType value, @@ -117,27 +126,20 @@ struct IpuStrategy { return it->second(); } - public: - void AddBoolOption(const std::string &option, bool value); - void AddUint64Option(const std::string &option, std::uint64_t value); - void AddDoubleOption(const std::string &option, double value); - void AddStringOption(const std::string &option, const std::string &value); - void InsertStringOption(const std::string &option, const std::string &value); - void InsertStringPairOption(const std::string &option, const std::string &key, - const std::string &value); - void SetTensorLocation(const std::string &tensor, const std::string &option, - std::uint64_t value); - void AddCustomOp(const std::string &paddle_op, const std::string &popart_op, - const std::string &domain, int version); - - std::string GetOption(const std::string &); - std::vector GetVectorOption(const std::string &); - std::map GetMapOption(const std::string &); - std::string GetOptionType(const std::string &); + std::map> bool_options; + std::map> uint64_options; + std::map> double_options; + std::map> string_options; + std::map)>> + container_options; - void EnablePattern(const std::string &t); - void DisablePattern(const std::string &t); - const bool IsPatternEnabled(const std::string &t); + std::map> options_getter; + std::map()>> + vector_options_getter; + std::map()>> + map_options_getter; + std::map options_type; }; } // namespace ipu diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 6e553ad2e60e292881fa8bb0294ea2a247656b67..3d8815e2eb61b53a6c8447fc8ce09a9c113963f2 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3919,6 +3919,8 @@ All parameter, weight, gradient are variables in Paddle. } return res; }) + .def("get_all_option_names", + &platform::ipu::IpuStrategy::GetAllOptionNames) .def("enable_pattern", &platform::ipu::IpuStrategy::EnablePattern) .def("disable_pattern", &platform::ipu::IpuStrategy::DisablePattern) .def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled); diff --git a/python/paddle/fluid/tests/unittests/ipu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ipu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..959700ad743b40420200b56055354279386a9a7c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ipu/CMakeLists.txt @@ -0,0 +1,8 @@ +if(WITH_IPU) + file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") + string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + + foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) + endforeach(TEST_OP) +endif()