未验证 提交 6ec89eeb 编写于 作者: A Allen Guo 提交者: GitHub

[IPU] merge recent changes (#42078)

* merge recent changes

* fix setting pipline
上级 20286ae7
......@@ -185,12 +185,9 @@ void Compiler::RegisterOpFunc() {
auto debug_context = BuildDebugContext(op_desc); \
auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
PushNameScope(op_desc); \
NameScopeHelper ns_helper(op_desc, builder_.get()); \
auto output_ids = OnnxImpl(inputs Args, debug_context); \
PopNameScope(op_desc); \
SetIpuIndexStage(output_ids, op_desc); \
SetAMPAttributes(output_ids, op_desc); \
SetSerializeAttributes(output_ids, op_desc); \
PostLower(output_ids, op_desc); \
InsertTensors(output_names, output_ids); \
}}, // NOLINT
#include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h"
......@@ -273,10 +270,9 @@ void Compiler::LowerConstants(const Scope* scope) {
popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()),
shape);
const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
PushNameScope(op_desc);
NameScopeHelper ns_helper(op_desc, builder_.get());
popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
PopNameScope(op_desc);
SetIpuIndexStage(result, op_desc);
PostLower(result, op_desc);
resources_->tensors.emplace(tensor_name, result);
}
}
......@@ -285,42 +281,42 @@ void Compiler::LowerConstants(const Scope* scope) {
void Compiler::LowerWeights(const Scope* scope) {
VLOG(10) << "enter Compiler::LowerWeights";
// at this step, the graph doesn't contains optimizer related states
// At this step, the graph doesn't contains optimizer related states
for (auto id : graph_helper_->sorted_vars_id) {
auto* node = graph_helper_->nodes_id_map[id];
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
if (node->Var()->Persistable() && node->inputs.empty()) {
auto var_name = node->Var()->Name();
if (resources_->tensors.count(var_name) != 0) {
VLOG(10) << "found existed one, skip lowering Weight: " << var_name;
continue;
}
if (var_name.rfind("learning_rate", 0) == 0) {
VLOG(10) << "skip learning_rate_var: " << var_name;
continue;
}
VLOG(10) << "lowering weight: " << var_name;
auto var = scope->FindVar(var_name);
if (var) {
auto tensor = var->Get<framework::LoDTensor>();
auto dtype = PdDataType2PopartType(tensor.dtype());
auto shape = std::vector<int64_t>();
for (size_t i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data(), tensor_info};
if (!node->outputs.empty()) {
auto op_node = node->outputs[0];
PushNameScope(op_node->Op());
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
PopNameScope(op_node->Op());
resources_->tensors.emplace(var_name, result);
resources_->weights.push_back(var_name);
}
}
// Weights are var node and Persistable
if (node->IsVar() && !node->IsCtrlVar() && node->Var() &&
node->Var()->Persistable()) {
// Weights are Parameter in training mode
if (ipu_strategy_->is_training && !node->Var()->IsParameter()) {
continue;
}
auto var_name = node->Var()->Name();
// Some op has same input and output tensor, like batchnorm
if (resources_->tensors.count(var_name) != 0) {
VLOG(10) << "found existed one, skip lowering Weight: " << var_name;
continue;
}
VLOG(10) << "lowering weight: " << var_name;
auto var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Tensor %s is not found in the scope",
var_name));
auto tensor = var->Get<framework::LoDTensor>();
auto dtype = PdDataType2PopartType(tensor.dtype());
auto shape = std::vector<int64_t>();
for (size_t i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data(), tensor_info};
if (!node->outputs.empty()) {
auto op_node = node->outputs[0];
NameScopeHelper ns_helper(op_node->Op(), builder_.get());
popart::TensorId result =
builder_->addInitializedInputTensor(const_data, var_name);
resources_->tensors.emplace(var_name, result);
resources_->weights.push_back(var_name);
}
}
}
......@@ -341,10 +337,9 @@ void Compiler::LowerBody() {
} else if (op_type == "popart_checkpointoutput") {
auto inputs = GetOpInputs(op_desc);
auto outputs = GetOpOutputs(op_desc);
PushNameScope(op_desc);
NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids = builder_->checkpointOutput(inputs);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
PostLower(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_custom_op") {
auto inputs = GetOpInputs(op_desc);
......@@ -359,12 +354,11 @@ void Compiler::LowerBody() {
BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
VLOG(10) << "Build graph from custom op: " << __op_type;
auto it = custom_ops_.find(__op_type);
PushNameScope(op_desc);
NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids =
builder_->customOp(it->second.popart_op, it->second.popart_op.version,
inputs, outputs.size(), attributes, debug_context);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
PostLower(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else if (op_type == "popart_printtensor") {
auto inputs = GetOpInputs(op_desc);
......@@ -373,11 +367,10 @@ void Compiler::LowerBody() {
auto print_gradient =
BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
auto title = BOOST_GET_CONST(std::string, op_desc->GetAttr("title"));
PushNameScope(op_desc);
NameScopeHelper ns_helper(op_desc, builder_.get());
auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
inputs, print_gradient, debug_context, title);
PopNameScope(op_desc);
SetIpuIndexStage(output_ids, op_desc);
PostLower(output_ids, op_desc);
InsertTensors(outputs, output_ids);
} else {
auto itr = name_function_.find(op_type);
......@@ -625,12 +618,13 @@ void Compiler::InsertTensors(const std::vector<std::string>& output_names,
resources_->tensors.emplace(output_names[0], tensor_id);
}
void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids,
const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetIpuIndexStage";
void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
const OpDesc* op_desc) {
// Set pipline
// Due to the limitation of popart, if an op has multiple outputs,
// pipline settings needs to be set at the same time
auto tensor_ids_set =
std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
if (op_desc->HasAttr(sIpuIndexAttr)) {
auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
builder_->virtualGraph(tensor_ids_set, ipu_index);
......@@ -639,18 +633,24 @@ void Compiler::SetIpuIndexStage(const std::vector<std::string>& tensor_ids,
if (op_desc->HasAttr(sIpuStageAttr)) {
auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
builder_->pipelineStage(tensor_ids_set, ipu_stage);
VLOG(10) << "set " << sIpuStageAttr << "= " << ipu_stage
VLOG(10) << "set " << sIpuStageAttr << " = " << ipu_stage
<< " for op: " << op_desc->Type();
}
}
VLOG(10) << "leave Compiler::SetIpuIndexStage";
for (auto& tensor_id : tensor_ids) {
PostLower(tensor_id, op_desc, true);
}
}
void Compiler::SetIpuIndexStage(const std::string& tensor_id,
const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetIpuIndexStage";
void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc) {
PostLower(tensor_id, op_desc, false);
}
if (op_desc->HasAttr(sIpuIndexAttr)) {
void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc,
bool skip_pipline) {
// Set pipline
if (!skip_pipline && op_desc->HasAttr(sIpuIndexAttr)) {
auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
builder_->virtualGraph(tensor_id, ipu_index);
VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index
......@@ -658,32 +658,18 @@ void Compiler::SetIpuIndexStage(const std::string& tensor_id,
if (op_desc->HasAttr(sIpuStageAttr)) {
auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
builder_->pipelineStage(tensor_id, ipu_stage);
VLOG(10) << "set " << sIpuStageAttr << "= " << ipu_stage
VLOG(10) << "set " << sIpuStageAttr << " = " << ipu_stage
<< " for op: " << op_desc->Type();
}
}
VLOG(10) << "leave Compiler::SetIpuIndexStage";
}
void Compiler::SetAMPAttributes(const std::vector<std::string>& tensor_ids,
const OpDesc* op_desc) {
if (op_desc->Type() == "popart_matmul") {
for (const auto& tensor_id : tensor_ids) {
SetAMPAttributes(tensor_id, op_desc);
}
}
}
void Compiler::SetAMPAttributes(const std::string& tensor_id,
const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetAMPAttributes";
// Set amp
if (op_desc->Type() == "popart_matmul") {
if (set_amp_for_all_) {
auto amp = ipu_strategy_->available_memory_proportion;
if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 <= "
"amp <= 1",
"AvailableMemoryProportion %f is invalid, which should be in "
"range [0.0, 1.0]",
amp));
}
if (amp > 0.0f) {
......@@ -694,8 +680,8 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id,
auto amp = BOOST_GET_CONST(float, op_desc->GetAttr(sAvailMemAttribute));
if (amp < 0.0f || amp > 1.0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AvailableMemoryProportion %f is invalid, which should be set 0 "
"<= amp <= 1",
"AvailableMemoryProportion %f is invalid, which should be in "
"range [0.0, 1.0]",
amp));
}
if (amp > 0.0f) {
......@@ -705,17 +691,7 @@ void Compiler::SetAMPAttributes(const std::string& tensor_id,
}
}
}
}
VLOG(10) << "leave Compiler::SetAMPAttributes";
}
void Compiler::SetSerializeAttributes(
const std::vector<std::string>& tensor_ids, const OpDesc* op_desc) {
VLOG(10) << "enter Compiler::SetSerializeAttributes";
auto tensor_ids_set =
std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
if (op_desc->Type() == "popart_matmul") {
// Set serialize matmul
if (op_desc->HasAttr(sMatmulSerializeFactor)) {
auto factor =
BOOST_GET_CONST(int, op_desc->GetAttr(sMatmulSerializeFactor));
......@@ -724,16 +700,9 @@ void Compiler::SetSerializeAttributes(
mode = BOOST_GET_CONST(std::string,
op_desc->GetAttr(sMatmulSerializeMode));
}
builder_->setSerializeMatMul(tensor_ids_set, mode, (int64_t)factor, true);
builder_->setSerializeMatMul({tensor_id}, mode, factor, true);
}
}
VLOG(10) << "leave Compiler::SetSerializeAttributes";
}
void Compiler::SetSerializeAttributes(const std::string& tensor_id,
const OpDesc* op_desc) {
std::vector<std::string> tensor_ids = {tensor_id};
SetSerializeAttributes(tensor_ids, op_desc);
}
void Compiler::SetCustomOps(
......@@ -793,29 +762,6 @@ popart::DebugContext Compiler::BuildDebugContext(const OpDesc* op) {
return popart::DebugContext(op_identify_id);
}
void Compiler::PushNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
if (!op_namescope.empty()) {
op_namescope.pop_back();
}
if (!op_namescope.empty()) {
op_namescope.erase(op_namescope.begin());
}
VLOG(10) << "name_scope is: " << op_namescope;
builder_->pushNameScope(op_namescope);
}
void Compiler::PopNameScope(const OpDesc* op) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope == "/") {
return;
}
builder_->popNameScope();
}
} // namespace ipu
} // namespace platform
} // namespace paddle
......@@ -70,7 +70,7 @@ struct CompilerResources {
std::unique_ptr<popart::Optimizer> optimizer;
};
// helper for lowering graph
// Helper for lowering graph
struct GraphHelper {
explicit GraphHelper(const Graph *);
......@@ -81,6 +81,30 @@ struct GraphHelper {
std::vector<int> sorted_vars_id;
};
// Helper for adding namescope info
struct NameScopeHelper {
NameScopeHelper(const OpDesc *op, popart::Builder *builder)
: builder_(builder) {
auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
if (op_namescope.empty() || op_namescope == "/") {
return;
}
op_namescope.pop_back();
op_namescope.erase(op_namescope.begin());
builder->pushNameScope(op_namescope);
pushed_ = true;
}
~NameScopeHelper() {
if (pushed_) {
builder_->popNameScope();
}
}
bool pushed_ = false;
popart::Builder *builder_;
};
class Compiler {
public:
Compiler();
......@@ -119,18 +143,9 @@ class Compiler {
const std::vector<std::string> &tensor_ids);
void InsertTensors(const std::vector<std::string> &output_names,
const std::string &tensor_id);
void SetIpuIndexStage(const std::vector<std::string> &tensor_ids,
const OpDesc *op_desc);
void SetIpuIndexStage(const std::string &tensor_id, const OpDesc *op_desc);
void SetAMPAttributes(const std::vector<std::string> &tensor_ids,
const OpDesc *op_desc);
void SetAMPAttributes(const std::string &tensor_id, const OpDesc *op_desc);
void SetSerializeAttributes(const std::vector<std::string> &tensor_ids,
const OpDesc *op_desc);
void SetSerializeAttributes(const std::string &tensor_id,
const OpDesc *op_desc);
void PushNameScope(const OpDesc *op);
void PopNameScope(const OpDesc *op);
void PostLower(const std::vector<std::string> &, const OpDesc *);
void PostLower(const std::string &, const OpDesc *);
void PostLower(const std::string &, const OpDesc *, bool);
private:
std::unique_ptr<popart::Builder> builder_;
......
......@@ -20,6 +20,40 @@ namespace paddle {
namespace platform {
namespace ipu {
// Get paddle prefix and popart postfix of weight states
// Format: {popart_postfix, paddle_prefix}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
const std::string &opt_type) {
std::vector<std::pair<std::string, std::string>> pre_post_fix;
// Weight self
pre_post_fix.push_back(std::make_pair("", ""));
// Weight states
// TODO(alleng) support pair("Accl1___", "_moment1_{id!=0}")
if (opt_type == "adam" || opt_type == "lamb" || opt_type == "adamw") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
} else if (opt_type == "momentum") {
pre_post_fix.push_back(std::make_pair("Accl___", "_velocity_0"));
} else if (opt_type == "adamax") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_inf_norm__0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
} else if (opt_type == "adagrad") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment_0"));
} else if (opt_type == "adadelta") {
pre_post_fix.push_back(std::make_pair("Accl1___", "__avg_squared_grad_0"));
pre_post_fix.push_back(
std::make_pair("Accl2___", "__avg_squared_update_0"));
} else if (opt_type == "rmsprop") {
pre_post_fix.push_back(std::make_pair("Accl1___", "_mean_square_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_mean_grad_0"));
pre_post_fix.push_back(std::make_pair("Accl3___", "_momentum__0"));
}
return pre_post_fix;
}
Executor::~Executor() {
Detach();
session_.reset();
......
......@@ -412,6 +412,15 @@ IpuStrategy::IpuStrategy() {
RegisterGetter(map_options_getter, options_type, "gcl_options", "map",
[&]() { return popart_options.gclOptions; });
// Default options
// Can also be set as a custom logger in python, like using tqdm
popart_options.compilationProgressLogger = [](int progress, int total) {
if (progress % 10 == 0) {
VLOG(1) << "compile progress: " << progress << "%";
}
};
}
void IpuStrategy::AddBoolOption(const std::string& option, bool value) {
......@@ -513,6 +522,11 @@ void IpuStrategy::AddCustomOp(const std::string& paddle_op,
IpuCustomOpIdentifier(paddle_op, popart_op, domain, version));
}
void IpuStrategy::SetCompilationProgressLogger(
const std::function<void(int, int)>& logger) {
popart_options.compilationProgressLogger = logger;
}
std::string IpuStrategy::GetOption(const std::string& option) {
return get(option, options_getter);
}
......
......@@ -125,6 +125,8 @@ class IpuStrategy {
const std::vector<int> &values);
void AddCustomOp(const std::string &paddle_op, const std::string &popart_op,
const std::string &domain, int version);
void SetCompilationProgressLogger(
const std::function<void(int, int)> &logger);
std::string GetOption(const std::string &);
std::vector<std::string> GetVectorOption(const std::string &);
......
......@@ -184,27 +184,6 @@ bool GetBoolEnv(std::string str) {
}
}
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
const std::string& opt_type) {
// format: {popart_tensor_id, paddle_tensor_id}, ...
std::vector<std::pair<std::string, std::string>> pre_post_fix;
if (opt_type == "adam" || opt_type == "lamb") {
pre_post_fix.push_back(std::make_pair("", ""));
pre_post_fix.push_back(std::make_pair("Accl1___", "_moment1_0"));
pre_post_fix.push_back(std::make_pair("Accl2___", "_moment2_0"));
pre_post_fix.push_back(std::make_pair("Step___", "_beta1_pow_acc_0"));
} else if (opt_type == "sgd" || opt_type == "momentum") {
// sgd
pre_post_fix.push_back(std::make_pair("", ""));
} else {
pre_post_fix.push_back(std::make_pair("", ""));
//
}
return pre_post_fix;
}
int RequestIpus(const int num_ipus) {
// num_ipus must be pow(2, n);
return std::pow(2, ceil(log2(num_ipus)));
......
......@@ -229,9 +229,6 @@ struct ConstantOpAttrVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { RaiseError(); }
};
std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
const std::string& opt_type);
int RequestIpus(const int num_ipus);
} // namespace ipu
......
......@@ -4357,7 +4357,10 @@ All parameter, weight, gradient are variables in Paddle.
for (auto element : opt) {
auto option_name = element.first.cast<std::string>();
VLOG(10) << "Set option: " << option_name;
if (py::isinstance<py::bool_>(element.second)) {
if (option_name == "compilation_progress_logger") {
self.SetCompilationProgressLogger(
element.second.cast<py::function>());
} else if (py::isinstance<py::bool_>(element.second)) {
self.AddBoolOption(option_name, element.second.cast<bool>());
} else if (py::isinstance<py::float_>(element.second)) {
self.AddDoubleOption(option_name,
......
......@@ -11,4 +11,5 @@ if(WITH_IPU)
set_tests_properties(test_conv_op_ipu PROPERTIES TIMEOUT 300)
set_tests_properties(test_elemetwise_x_op_ipu PROPERTIES TIMEOUT 300)
set_tests_properties(test_reduce_x_op_ipu PROPERTIES TIMEOUT 600)
set_tests_properties(test_save_load_ipu PROPERTIES TIMEOUT 600)
endif()
......@@ -73,10 +73,15 @@ class TestIpuStrategy(unittest.TestCase):
'autoReport.directory': 'path',
'autoReport.all': 'true'
}
options['random_seed'] = 1234
for k, v in options.items():
ipu_strategy.set_options({k: v})
assert v == ipu_strategy.get_option(k), f"set {k} to {v} failed "
# The custom logger need 2 int as inputs
logger = lambda progress, total: print(f"compile progrss: {progress}/{total}")
ipu_strategy.set_options({'compilation_progress_logger': logger})
if __name__ == "__main__":
unittest.main()
......@@ -14,9 +14,11 @@
import tempfile
import unittest
from functools import partial
import numpy as np
import paddle
import paddle.optimizer
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
......@@ -28,7 +30,8 @@ class TestBase(IPUOpTest):
self.set_atol()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
self.set_attrs()
self.set_optimizer()
def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 10, 10])
......@@ -39,15 +42,16 @@ class TestBase(IPUOpTest):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_op_attrs(self):
def set_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'sgd'
self.attrs['enable_fp16'] = False
self.attrs['model_path'] = tempfile.TemporaryDirectory()
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.SGD, learning_rate=1e-1)
def _test_base(self, save_otherwise_load):
scope = paddle.static.Scope()
main_prog = paddle.static.Program()
......@@ -71,16 +75,8 @@ class TestBase(IPUOpTest):
name='conv2d')
loss = paddle.mean(conv1)
if self.attrs['is_training']:
if self.attrs['opt_type'] == 'sgd':
sgd = paddle.optimizer.SGD(learning_rate=1e-2)
sgd.minimize(loss)
elif self.attrs['opt_type'] == 'adam':
adam = paddle.optimizer.Adam(learning_rate=1e-2)
adam.minimize(loss)
elif self.attrs['opt_type'] == 'lamb':
lamb = paddle.optimizer.Lamb(learning_rate=1e-2)
lamb.minimize(loss)
# apply optimizer
self.optimizer().minimize(loss)
fetch_list = [loss.name]
place = paddle.IPUPlace()
......@@ -91,8 +87,7 @@ class TestBase(IPUOpTest):
paddle.static.load(main_prog, self.attrs['model_path'].name)
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(
is_training=self.attrs['is_training'])
ipu_strategy.set_graph_config(is_training=True)
ipu_strategy.set_precision_config(
enable_fp16=self.attrs['enable_fp16'])
ipu_program = paddle.static.IpuCompiledProgram(
......@@ -131,62 +126,109 @@ class TestBase(IPUOpTest):
self.attrs['model_path'].cleanup()
class TestMomentum(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Momentum, learning_rate=1e-1)
class TestAdam(TestBase):
def set_op_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'adam'
self.attrs['enable_fp16'] = False
self.attrs['model_path'] = tempfile.TemporaryDirectory()
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adam, learning_rate=1e-1)
class TestLamb(TestBase):
def set_op_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'lamb'
self.attrs['enable_fp16'] = False
self.attrs['model_path'] = tempfile.TemporaryDirectory()
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Lamb, learning_rate=1e-1)
class TestAdamW(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.AdamW, learning_rate=1e-1)
class TestAdamax(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adamax, learning_rate=1e-1)
class TestAdagrad(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestAdadelta(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestRMSProp(TestBase):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.RMSProp, learning_rate=1e-1)
class TestCenteredRMSProp(TestBase):
def set_optimizer(self):
self.optimizer = partial(
paddle.optimizer.RMSProp, learning_rate=1e-1, centered=True)
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestSGDFP16(TestBase):
def set_op_attrs(self):
def set_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'sgd'
self.attrs['enable_fp16'] = True
self.attrs['model_path'] = tempfile.TemporaryDirectory()
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.SGD, learning_rate=1e-1)
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestAdamFP16(TestBase):
def set_op_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'adam'
self.attrs['enable_fp16'] = True
self.attrs['model_path'] = tempfile.TemporaryDirectory()
class TestMomentumFp16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Momentum, learning_rate=1e-1)
@unittest.skipIf(IPUOpTest.use_ipumodel(), "skip for ipumodel")
class TestLambFP16(TestBase):
def set_op_attrs(self):
self.attrs = {}
self.attrs['steps'] = 100
self.attrs['save_at_step'] = 20
self.attrs['is_training'] = True
self.attrs['opt_type'] = 'lamb'
self.attrs['enable_fp16'] = True
self.attrs['model_path'] = tempfile.TemporaryDirectory()
class TestAdamFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adam, learning_rate=1e-1)
class TestLambFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Lamb, learning_rate=1e-1)
class TestAdamWFP16FP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.AdamW, learning_rate=1e-1)
class TestAdamaxFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adamax, learning_rate=1e-1)
class TestAdagradFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestAdadeltaFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.Adagrad, learning_rate=1e-1)
class TestRMSPropFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(paddle.optimizer.RMSProp, learning_rate=1e-1)
class TestCenteredRMSPropFP16(TestSGDFP16):
def set_optimizer(self):
self.optimizer = partial(
paddle.optimizer.RMSProp, learning_rate=1e-1, centered=True)
if __name__ == "__main__":
......
# A image for building paddle binaries
# build docker image
# docker build -t paddlepaddle/paddle:ipu-dev-2.3.0 -f tools/dockerfile/Dockerfile.ipu .
# docker build -t paddlepaddle/paddle:latest-dev-ipu -f tools/dockerfile/Dockerfile.ipu .
# run a container
# docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK --device=/dev/infiniband/ --ipc=host --rm -it paddlepaddle/paddle:ipu-dev-2.3.0 bash
# docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK --device=/dev/infiniband/ --ipc=host --rm -it paddlepaddle/paddle:latest-dev-ipu bash
FROM graphcore/poplar:2.3.0
MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册