未验证 提交 69ab2700 编写于 作者: A Allen Guo 提交者: GitHub

fix compiling and running with ipu (#39920)

上级 09039636
......@@ -125,6 +125,15 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
}
#endif
#ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing IPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
}
#endif
return phi::KernelKey();
}
......
......@@ -120,121 +120,151 @@ IpuStrategy::IpuStrategy() {
RegisterGetter(options_getter, options_type, #name, "string", \
[&]() { return popart_options.aliased_name; })
#define ADD_POPART_ENUM_OPTION(name, EnumType) \
ADD_POPART_ENUM_OPTION_ALIAS(name, name, EnumType)
#define ADD_POPART_BOOL_OPTION(name) ADD_POPART_BOOL_OPTION_ALIAS(name, name)
#define ADD_POPART_UINT64_OPTION(name) \
ADD_POPART_UINT64_OPTION_ALIAS(name, name)
#define ADD_POPART_DOUBLE_OPTION(name) \
ADD_POPART_DOUBLE_OPTION_ALIAS(name, name)
#define ADD_POPART_STRING_OPTION(name) \
ADD_POPART_STRING_OPTION_ALIAS(name, name)
ADD_POPART_ENUM_OPTION(autodiffSettings.stitchStrategy,
ADD_POPART_ENUM_OPTION_ALIAS(autodiff_settings.stitch_strategy,
autodiffSettings.stitchStrategy,
AutodiffStitchStrategy);
ADD_POPART_ENUM_OPTION(batchSerializationSettings.transformContext,
ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.transform_context,
batchSerializationSettings.transformContext,
BatchSerializationTransformContext);
ADD_POPART_ENUM_OPTION(batchSerializationSettings.method,
ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.method,
batchSerializationSettings.method,
BatchSerializationMethod);
ADD_POPART_ENUM_OPTION(batchSerializationSettings.batchSchedule,
ADD_POPART_ENUM_OPTION_ALIAS(batch_serialization_settings.batch_schedule,
batchSerializationSettings.batchSchedule,
BatchSerializationBatchSchedule);
ADD_POPART_ENUM_OPTION(autoRecomputation, RecomputationType);
ADD_POPART_ENUM_OPTION(mergeVarUpdate, MergeVarUpdateType);
ADD_POPART_ENUM_OPTION(virtualGraphMode, VirtualGraphMode);
ADD_POPART_ENUM_OPTION(syntheticDataMode, SyntheticDataMode);
ADD_POPART_ENUM_OPTION(subgraphCopyingStrategy, SubgraphCopyingStrategy);
ADD_POPART_ENUM_OPTION(accumulationAndReplicationReductionType,
ADD_POPART_ENUM_OPTION_ALIAS(auto_recomputation, autoRecomputation,
RecomputationType);
ADD_POPART_ENUM_OPTION_ALIAS(merge_var_update, mergeVarUpdate,
MergeVarUpdateType);
ADD_POPART_ENUM_OPTION_ALIAS(virtual_graph_mode, virtualGraphMode,
VirtualGraphMode);
ADD_POPART_ENUM_OPTION_ALIAS(synthetic_data_mode, syntheticDataMode,
SyntheticDataMode);
ADD_POPART_ENUM_OPTION_ALIAS(subgraph_copying_strategy,
subgraphCopyingStrategy,
SubgraphCopyingStrategy);
ADD_POPART_ENUM_OPTION_ALIAS(accumulation_and_replication_reduction_type,
accumulationAndReplicationReductionType,
ReductionType);
ADD_POPART_ENUM_OPTION(meanAccumulationAndReplicationReductionStrategy,
MeanReductionStrategy);
ADD_POPART_STRING_OPTION(logDir);
ADD_POPART_STRING_OPTION(cachePath);
ADD_POPART_STRING_OPTION(partialsTypeMatMuls);
ADD_POPART_STRING_OPTION(customCodeletCompileFlags);
ADD_POPART_STRING_OPTION(serializedPoprithmsShiftGraphsDir);
ADD_POPART_STRING_OPTION(kahnTieBreaker);
ADD_POPART_UINT64_OPTION(executionPhaseSettings.phases);
ADD_POPART_UINT64_OPTION(executionPhaseSettings.stages);
ADD_POPART_UINT64_OPTION(batchSerializationSettings.factor);
ADD_POPART_UINT64_OPTION(firstDotOp);
ADD_POPART_UINT64_OPTION(finalDotOp);
ADD_POPART_UINT64_OPTION(numIOTiles);
ADD_POPART_UINT64_OPTION(mergeVarUpdateMemThreshold);
ADD_POPART_UINT64_OPTION(looseThresholdAtPeak);
ADD_POPART_UINT64_OPTION(accumulationFactor);
ADD_POPART_UINT64_OPTION(swapLimitScheduler);
ADD_POPART_UINT64_OPTION(globalReplicationFactor);
ADD_POPART_UINT64_OPTION(globalReplicaOffset);
ADD_POPART_UINT64_OPTION(defaultPrefetchBufferingDepth);
ADD_POPART_UINT64_OPTION(compilationProgressTotal);
ADD_POPART_UINT64_OPTION(transitiveClosureOptimizationThreshold);
ADD_POPART_BOOL_OPTION(batchSerializationSettings.concatOnVirtualGraphChange);
ADD_POPART_BOOL_OPTION(
ADD_POPART_ENUM_OPTION_ALIAS(
mean_accumulation_and_replication_reduction_strategy,
meanAccumulationAndReplicationReductionStrategy, MeanReductionStrategy);
ADD_POPART_STRING_OPTION_ALIAS(log_dir, logDir);
ADD_POPART_STRING_OPTION_ALIAS(cache_path, cachePath);
ADD_POPART_STRING_OPTION_ALIAS(partials_type_matmuls, partialsTypeMatMuls);
ADD_POPART_STRING_OPTION_ALIAS(custom_codelet_compile_flags,
customCodeletCompileFlags);
ADD_POPART_STRING_OPTION_ALIAS(serialized_poprithms_shift_graphs_dir,
serializedPoprithmsShiftGraphsDir);
ADD_POPART_STRING_OPTION_ALIAS(kahn_tie_breaker, kahnTieBreaker);
ADD_POPART_UINT64_OPTION_ALIAS(execution_phase_settings.phases,
executionPhaseSettings.phases);
ADD_POPART_UINT64_OPTION_ALIAS(execution_phase_settings.stages,
executionPhaseSettings.stages);
ADD_POPART_UINT64_OPTION_ALIAS(batch_serialization_settings.factor,
batchSerializationSettings.factor);
ADD_POPART_UINT64_OPTION_ALIAS(first_dot_op, firstDotOp);
ADD_POPART_UINT64_OPTION_ALIAS(final_dot_op, finalDotOp);
ADD_POPART_UINT64_OPTION_ALIAS(num_io_tiles, numIOTiles);
ADD_POPART_UINT64_OPTION_ALIAS(merge_var_update_mem_threshold,
mergeVarUpdateMemThreshold);
ADD_POPART_UINT64_OPTION_ALIAS(loose_threshold_at_peak, looseThresholdAtPeak);
ADD_POPART_UINT64_OPTION_ALIAS(accumulation_factor, accumulationFactor);
ADD_POPART_UINT64_OPTION_ALIAS(swap_limit_scheduler, swapLimitScheduler);
ADD_POPART_UINT64_OPTION_ALIAS(global_replication_factor,
globalReplicationFactor);
ADD_POPART_UINT64_OPTION_ALIAS(global_replica_offset, globalReplicaOffset);
ADD_POPART_UINT64_OPTION_ALIAS(default_prefetch_buffering_depth,
defaultPrefetchBufferingDepth);
ADD_POPART_UINT64_OPTION_ALIAS(compilation_progress_total,
compilationProgressTotal);
ADD_POPART_UINT64_OPTION_ALIAS(transitive_closure_optimization_threshold,
transitiveClosureOptimizationThreshold);
ADD_POPART_BOOL_OPTION_ALIAS(
batch_serialization_settings.concat_on_virtual_graph_change,
batchSerializationSettings.concatOnVirtualGraphChange);
ADD_POPART_BOOL_OPTION_ALIAS(
batch_serialization_settings.concat_on_execution_phase_change,
batchSerializationSettings.concatOnExecutionPhaseChange);
ADD_POPART_BOOL_OPTION(
ADD_POPART_BOOL_OPTION_ALIAS(
batch_serialization_settings.concat_on_pipeline_stage_change,
batchSerializationSettings.concatOnPipelineStageChange);
ADD_POPART_BOOL_OPTION(strictOpVersions);
ADD_POPART_BOOL_OPTION(opxAliasChecking);
ADD_POPART_BOOL_OPTION(opxModifyChecking);
ADD_POPART_BOOL_OPTION(dotOpNames);
ADD_POPART_BOOL_OPTION(exportPoplarComputationGraph);
ADD_POPART_BOOL_OPTION(exportPoplarVertexGraph);
ADD_POPART_BOOL_OPTION(separateCallOpPdfs);
ADD_POPART_BOOL_OPTION(enableOutlining);
ADD_POPART_BOOL_OPTION(enableOutliningCopyCostPruning);
ADD_POPART_BOOL_OPTION(rearrangeAnchorsOnHost);
ADD_POPART_BOOL_OPTION(enablePrefetchDatastreams);
ADD_POPART_BOOL_OPTION(enableNonStableSoftmax);
ADD_POPART_BOOL_OPTION(enableReplicatedGraphs);
ADD_POPART_BOOL_OPTION(enableGradientAccumulation);
ADD_POPART_BOOL_OPTION(instrumentWithHardwareCycleCounter);
ADD_POPART_BOOL_OPTION(enablePipelining);
ADD_POPART_BOOL_OPTION_ALIAS(strict_op_versions, strictOpVersions);
ADD_POPART_BOOL_OPTION_ALIAS(opx_alias_checking, opxAliasChecking);
ADD_POPART_BOOL_OPTION_ALIAS(opx_modify_checking, opxModifyChecking);
ADD_POPART_BOOL_OPTION_ALIAS(dot_op_names, dotOpNames);
ADD_POPART_BOOL_OPTION_ALIAS(export_poplar_computation_graph,
exportPoplarComputationGraph);
ADD_POPART_BOOL_OPTION_ALIAS(export_poplar_vertex_graph,
exportPoplarVertexGraph);
ADD_POPART_BOOL_OPTION_ALIAS(separate_call_op_pdfs, separateCallOpPdfs);
ADD_POPART_BOOL_OPTION_ALIAS(enable_outlining, enableOutlining);
ADD_POPART_BOOL_OPTION_ALIAS(enable_outlining_copy_cost_pruning,
enableOutliningCopyCostPruning);
ADD_POPART_BOOL_OPTION_ALIAS(rearrange_anchors_on_host,
rearrangeAnchorsOnHost);
ADD_POPART_BOOL_OPTION_ALIAS(enable_prefetch_datastreams,
enablePrefetchDatastreams);
ADD_POPART_BOOL_OPTION_ALIAS(enable_non_stable_softmax,
enableNonStableSoftmax);
ADD_POPART_BOOL_OPTION_ALIAS(enable_replicated_graphs,
enableReplicatedGraphs);
ADD_POPART_BOOL_OPTION_ALIAS(enable_gradient_accumulation,
enableGradientAccumulation);
ADD_POPART_BOOL_OPTION_ALIAS(instrument_with_hardware_cycle_counter,
instrumentWithHardwareCycleCounter);
ADD_POPART_BOOL_OPTION_ALIAS(enable_pipelining, enablePipelining);
ADD_POPART_BOOL_OPTION(disableGradAccumulationTensorStreams);
ADD_POPART_BOOL_OPTION(compileEngine);
ADD_POPART_BOOL_OPTION(constantWeights);
ADD_POPART_BOOL_OPTION(enableEngineCaching);
ADD_POPART_BOOL_OPTION(enableMergeExchange);
ADD_POPART_BOOL_OPTION(enableFloatingPointChecks);
ADD_POPART_BOOL_OPTION(enableStochasticRounding);
ADD_POPART_BOOL_OPTION_ALIAS(disable_grad_accumulation_tensor_streams,
disableGradAccumulationTensorStreams);
ADD_POPART_BOOL_OPTION_ALIAS(compile_engine, compileEngine);
ADD_POPART_BOOL_OPTION_ALIAS(constant_weights, constantWeights);
ADD_POPART_BOOL_OPTION_ALIAS(enable_engine_caching, enableEngineCaching);
ADD_POPART_BOOL_OPTION_ALIAS(enable_merge_exchange, enableMergeExchange);
ADD_POPART_BOOL_OPTION_ALIAS(enable_floating_point_checks,
enableFloatingPointChecks);
ADD_POPART_BOOL_OPTION_ALIAS(enable_stochastic_rounding,
enableStochasticRounding);
ADD_POPART_BOOL_OPTION(explicitRecomputation);
ADD_POPART_BOOL_OPTION(enableExplicitMainLoops);
ADD_POPART_BOOL_OPTION(useHostCopyOps);
ADD_POPART_BOOL_OPTION(aliasZeroCopy);
ADD_POPART_BOOL_OPTION(delayVarUpdates);
ADD_POPART_BOOL_OPTION(enableFullyConnectedPass);
ADD_POPART_BOOL_OPTION(enableSerializedMatmuls);
ADD_POPART_BOOL_OPTION(enableStableNorm);
ADD_POPART_BOOL_OPTION(decomposeGradSum);
ADD_POPART_BOOL_OPTION(enableDistributedReplicatedGraphs);
ADD_POPART_BOOL_OPTION(groupHostSync);
ADD_POPART_BOOL_OPTION(automaticLossScalingSettings.enabled);
ADD_POPART_BOOL_OPTION(instrumentWithHardwareCycleCounter);
ADD_POPART_BOOL_OPTION(enableSupportedDataTypeCasting);
ADD_POPART_BOOL_OPTION(groupNormStridedChannelGrouping);
ADD_POPART_BOOL_OPTION(scheduleNonWeightUpdateGradientConsumersEarly);
ADD_POPART_DOUBLE_OPTION(outlineSequenceBreakCost);
ADD_POPART_DOUBLE_OPTION(outlineThreshold);
ADD_POPART_DOUBLE_OPTION(timeLimitScheduler);
ADD_POPART_DOUBLE_OPTION(automaticLossScalingSettings.binEdgeLocation);
ADD_POPART_DOUBLE_OPTION(
ADD_POPART_BOOL_OPTION_ALIAS(explicit_recomputation, explicitRecomputation);
ADD_POPART_BOOL_OPTION_ALIAS(enable_explicit_main_loops,
enableExplicitMainLoops);
ADD_POPART_BOOL_OPTION_ALIAS(use_host_copy_ops, useHostCopyOps);
ADD_POPART_BOOL_OPTION_ALIAS(alias_zero_copy, aliasZeroCopy);
ADD_POPART_BOOL_OPTION_ALIAS(delay_var_updates, delayVarUpdates);
ADD_POPART_BOOL_OPTION_ALIAS(enable_fully_connected_pass,
enableFullyConnectedPass);
ADD_POPART_BOOL_OPTION_ALIAS(enable_serialized_matmuls,
enableSerializedMatmuls);
ADD_POPART_BOOL_OPTION_ALIAS(enable_stable_norm, enableStableNorm);
ADD_POPART_BOOL_OPTION_ALIAS(decompose_grad_sum, decomposeGradSum);
ADD_POPART_BOOL_OPTION_ALIAS(enable_distributed_replicated_graphs,
enableDistributedReplicatedGraphs);
ADD_POPART_BOOL_OPTION_ALIAS(group_host_sync, groupHostSync);
ADD_POPART_BOOL_OPTION_ALIAS(automatic_loss_scaling_settings.enabled,
automaticLossScalingSettings.enabled);
ADD_POPART_BOOL_OPTION_ALIAS(instrument_with_hardware_cycle_counter,
instrumentWithHardwareCycleCounter);
ADD_POPART_BOOL_OPTION_ALIAS(enable_supported_data_type_casting,
enableSupportedDataTypeCasting);
ADD_POPART_BOOL_OPTION_ALIAS(group_norm_strided_channel_grouping,
groupNormStridedChannelGrouping);
ADD_POPART_BOOL_OPTION_ALIAS(
schedule_non_weight_update_gradient_consumers_early,
scheduleNonWeightUpdateGradientConsumersEarly);
ADD_POPART_DOUBLE_OPTION_ALIAS(outline_sequence_break_cost,
outlineSequenceBreakCost);
ADD_POPART_DOUBLE_OPTION_ALIAS(outline_threshold, outlineThreshold);
ADD_POPART_DOUBLE_OPTION_ALIAS(time_limit_scheduler, timeLimitScheduler);
ADD_POPART_DOUBLE_OPTION_ALIAS(
automatic_loss_scaling_settings.bin_edge_location,
automaticLossScalingSettings.binEdgeLocation);
ADD_POPART_DOUBLE_OPTION_ALIAS(
automatic_loss_scaling_settings.threshold_upper_count_proportion,
automaticLossScalingSettings.thresholdUpperCountProportion);
#undef ADD_POPART_STRING_OPTION
#undef ADD_POPART_DOUBLE_OPTION
#undef ADD_POPART_UINT64_OPTION
#undef ADD_POPART_BOOL_OPTION
#undef ADD_POPART_ENUM_OPTION
#undef ADD_POPART_STRING_OPTION_ALIAS
#undef ADD_POPART_DOUBLE_OPTION_ALIAS
#undef ADD_POPART_UINT64_OPTION_ALIAS
......@@ -278,14 +308,14 @@ IpuStrategy::IpuStrategy() {
});
RegisterSetter(
container_options, "dotChecks",
container_options, "dot_checks",
[&](const std::pair<std::string, std::string>& p) {
std::uint64_t value = std::stoul(p.first);
popart_options.dotChecks.insert(static_cast<popart::DotCheck>(value));
});
RegisterGetter(
vector_options_getter, options_type, "dotChecks", "vector", [&]() {
vector_options_getter, options_type, "dot_checks", "vector", [&]() {
std::vector<std::string> res;
for (auto x : popart_options.dotChecks) {
res.push_back(std::to_string(static_cast<std::uint64_t>(x)));
......@@ -293,7 +323,7 @@ IpuStrategy::IpuStrategy() {
return res;
});
RegisterSetter(container_options, "hardwareInstrumentations",
RegisterSetter(container_options, "hardware_instrumentations",
[&](const std::pair<std::string, std::string>& p) {
std::uint64_t value = std::stoul(p.first);
popart_options.hardwareInstrumentations.insert(
......@@ -301,8 +331,8 @@ IpuStrategy::IpuStrategy() {
});
RegisterGetter(
vector_options_getter, options_type, "hardwareInstrumentations", "vector",
[&]() {
vector_options_getter, options_type, "hardware_instrumentations",
"vector", [&]() {
std::vector<std::string> res;
for (auto x : popart_options.hardwareInstrumentations) {
res.push_back(std::to_string(static_cast<std::uint64_t>(x)));
......@@ -310,12 +340,12 @@ IpuStrategy::IpuStrategy() {
return res;
});
RegisterSetter(container_options, "customCodelets",
RegisterSetter(container_options, "custom_codelets",
[&](const std::pair<std::string, std::string>& p) {
popart_options.customCodelets.push_back(p.first);
});
RegisterGetter(vector_options_getter, options_type, "customCodelets",
RegisterGetter(vector_options_getter, options_type, "custom_codelets",
"vector", [&]() {
std::vector<std::string> res;
for (auto x : popart_options.customCodelets) {
......@@ -324,44 +354,44 @@ IpuStrategy::IpuStrategy() {
return res;
});
RegisterSetter(container_options, "engineOptions",
RegisterSetter(container_options, "engine_options",
[&](const std::pair<std::string, std::string>& p) {
popart_options.engineOptions.emplace(p);
});
RegisterGetter(map_options_getter, options_type, "engineOptions", "map",
RegisterGetter(map_options_getter, options_type, "engine_options", "map",
[&]() { return popart_options.engineOptions; });
RegisterSetter(container_options, "reportOptions",
RegisterSetter(container_options, "report_options",
[&](const std::pair<std::string, std::string>& p) {
popart_options.reportOptions.emplace(p);
});
RegisterGetter(map_options_getter, options_type, "reportOptions", "map",
RegisterGetter(map_options_getter, options_type, "report_options", "map",
[&]() { return popart_options.reportOptions; });
RegisterSetter(container_options, "convolutionOptions",
RegisterSetter(container_options, "convolution_options",
[&](const std::pair<std::string, std::string>& p) {
popart_options.convolutionOptions.emplace(p);
});
RegisterGetter(map_options_getter, options_type, "convolutionOptions", "map",
RegisterGetter(map_options_getter, options_type, "convolution_options", "map",
[&]() { return popart_options.convolutionOptions; });
RegisterSetter(container_options, "lstmOptions",
RegisterSetter(container_options, "lstm_options",
[&](const std::pair<std::string, std::string>& p) {
popart_options.lstmOptions.emplace(p);
});
RegisterGetter(map_options_getter, options_type, "lstmOptions", "map",
RegisterGetter(map_options_getter, options_type, "lstm_options", "map",
[&]() { return popart_options.lstmOptions; });
RegisterSetter(container_options, "gclOptions",
RegisterSetter(container_options, "gcl_options",
[&](const std::pair<std::string, std::string>& p) {
popart_options.gclOptions.emplace(p);
});
RegisterGetter(map_options_getter, options_type, "gclOptions", "map",
RegisterGetter(map_options_getter, options_type, "gcl_options", "map",
[&]() { return popart_options.gclOptions; });
}
......@@ -415,21 +445,21 @@ void IpuStrategy::SetTensorLocation(const std::string& tensor,
"Unknown tensor location: %s", tensor));
}
if (opt == "minElementsForOffChip") {
if (opt == "min_elements_for_off_chip") {
settings->minElementsForOffChip = value;
} else if (opt == "minElementsForReplicatedTensorSharding") {
} else if (opt == "min_elements_for_replicated_tensor_sharding") {
settings->minElementsForReplicatedTensorSharding = value;
} else if (opt == "onChip") {
} else if (opt == "on_chip") {
settings->location.storage = value > 0 ? popart::TensorStorage::OnChip
: popart::TensorStorage::OffChip;
} else if (opt == "useReplicatedTensorSharding") {
} else if (opt == "use_replicated_tensor_sharding") {
settings->location.replicatedTensorSharding =
value > 0 ? popart::ReplicatedTensorSharding::On
: popart::ReplicatedTensorSharding::Off;
} else if (opt == "useIOTilesToLoad") {
} else if (opt == "use_io_tiles_to_load") {
settings->location.loadTileSet =
value > 0 ? popart::TileSet::IO : popart::TileSet::Compute;
} else if (opt == "useIOTilesToStore") {
} else if (opt == "use_io_tiles_to_store") {
settings->location.storageTileSet =
value > 0 ? popart::TileSet::IO : popart::TileSet::Compute;
} else {
......@@ -464,6 +494,20 @@ std::string IpuStrategy::GetOptionType(const std::string& option) {
return options_type[option];
}
// Returns the name (map key) of every option that has a registered getter.
// Names are gathered from options_getter first, then vector_options_getter,
// then map_options_getter; each group is appended in that container's own
// iteration order.
std::vector<std::string> IpuStrategy::GetAllOptionNames() {
  std::vector<std::string> all_names;
  for (auto it = options_getter.begin(); it != options_getter.end(); ++it) {
    all_names.push_back(it->first);
  }
  for (auto it = vector_options_getter.begin();
       it != vector_options_getter.end(); ++it) {
    all_names.push_back(it->first);
  }
  for (auto it = map_options_getter.begin(); it != map_options_getter.end();
       ++it) {
    all_names.push_back(it->first);
  }
  return all_names;
}
void IpuStrategy::EnablePattern(const std::string& t) {
VLOG(10) << "enable popart pattern: " << t;
popart_patterns.enablePattern(t, true);
......
......@@ -24,7 +24,8 @@ namespace paddle {
namespace platform {
namespace ipu {
struct IpuStrategy {
class IpuStrategy {
public:
IpuStrategy();
// TODO(alleng) create PaddleOptions
......@@ -75,22 +76,30 @@ struct IpuStrategy {
// custom ops
std::vector<IpuCustomOpIdentifier> custom_ops;
private:
std::map<std::string, std::function<void(bool)>> bool_options;
std::map<std::string, std::function<void(std::uint64_t)>> uint64_options;
std::map<std::string, std::function<void(double)>> double_options;
std::map<std::string, std::function<void(std::string)>> string_options;
std::map<std::string,
std::function<void(std::pair<std::string, std::string>)>>
container_options;
public:
void AddBoolOption(const std::string &option, bool value);
void AddUint64Option(const std::string &option, std::uint64_t value);
void AddDoubleOption(const std::string &option, double value);
void AddStringOption(const std::string &option, const std::string &value);
void InsertStringOption(const std::string &option, const std::string &value);
void InsertStringPairOption(const std::string &option, const std::string &key,
const std::string &value);
void SetTensorLocation(const std::string &tensor, const std::string &option,
std::uint64_t value);
void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
const std::string &domain, int version);
std::map<std::string, std::function<std::string()>> options_getter;
std::map<std::string, std::function<std::vector<std::string>()>>
vector_options_getter;
std::map<std::string, std::function<std::map<std::string, std::string>()>>
map_options_getter;
std::map<std::string, std::string> options_type;
std::string GetOption(const std::string &);
std::vector<std::string> GetVectorOption(const std::string &);
std::map<std::string, std::string> GetMapOption(const std::string &);
std::string GetOptionType(const std::string &);
std::vector<std::string> GetAllOptionNames();
void EnablePattern(const std::string &t);
void DisablePattern(const std::string &t);
const bool IsPatternEnabled(const std::string &t);
private:
template <typename ValueType>
void set(
const std::string &key, ValueType value,
......@@ -117,27 +126,20 @@ struct IpuStrategy {
return it->second();
}
public:
void AddBoolOption(const std::string &option, bool value);
void AddUint64Option(const std::string &option, std::uint64_t value);
void AddDoubleOption(const std::string &option, double value);
void AddStringOption(const std::string &option, const std::string &value);
void InsertStringOption(const std::string &option, const std::string &value);
void InsertStringPairOption(const std::string &option, const std::string &key,
const std::string &value);
void SetTensorLocation(const std::string &tensor, const std::string &option,
std::uint64_t value);
void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
const std::string &domain, int version);
std::string GetOption(const std::string &);
std::vector<std::string> GetVectorOption(const std::string &);
std::map<std::string, std::string> GetMapOption(const std::string &);
std::string GetOptionType(const std::string &);
std::map<std::string, std::function<void(bool)>> bool_options;
std::map<std::string, std::function<void(std::uint64_t)>> uint64_options;
std::map<std::string, std::function<void(double)>> double_options;
std::map<std::string, std::function<void(std::string)>> string_options;
std::map<std::string,
std::function<void(std::pair<std::string, std::string>)>>
container_options;
void EnablePattern(const std::string &t);
void DisablePattern(const std::string &t);
const bool IsPatternEnabled(const std::string &t);
std::map<std::string, std::function<std::string()>> options_getter;
std::map<std::string, std::function<std::vector<std::string>()>>
vector_options_getter;
std::map<std::string, std::function<std::map<std::string, std::string>()>>
map_options_getter;
std::map<std::string, std::string> options_type;
};
} // namespace ipu
......
......@@ -3919,6 +3919,8 @@ All parameter, weight, gradient are variables in Paddle.
}
return res;
})
.def("get_all_option_names",
&platform::ipu::IpuStrategy::GetAllOptionNames)
.def("enable_pattern", &platform::ipu::IpuStrategy::EnablePattern)
.def("disable_pattern", &platform::ipu::IpuStrategy::DisablePattern)
.def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled);
......
# Register the Python tests in this directory, but only when WITH_IPU is set.
if(WITH_IPU)
# Collect every file matching test_*.py, as paths relative to this directory.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
# Remove the ".py" text from the list so each entry is a bare module name.
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
# Hand each module name to py_test_modules, using it as both the test name and
# the MODULES argument.
# NOTE(review): py_test_modules is defined elsewhere; presumably it registers
# a test that runs the given Python module -- confirm against its definition.
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册