Unverified commit e0409c93, authored by Allen Guo, committed by GitHub

[IPU] Update IpuStrategy Python Part (#39646)

* Update IpuStrategy Python Part

* add docs

* add add_custom_op for ipu_strategy

* fix build warning

* rm unneeded part

* clean api

* fix typo

* update option names

* update IpuStrategy doc
Parent 1255e7d6
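The net effect of this change on user code: the camel-case IpuStrategy setters become snake_case, `SetHalfConfig` becomes `set_precision_config`, and the `accumulationFactor` argument becomes `accumulation_factor`. A minimal migration sketch (IPU build assumed):

    # required: ipu
    import paddle
    import paddle.static as static

    paddle.enable_static()
    ipu_strategy = static.IpuStrategy()
    # before: SetGraphConfig / SetPipeliningConfig / SetHalfConfig
    ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1)
    ipu_strategy.set_pipelining_config(enable_pipelining=False,
                                       batches_per_step=1,
                                       accumulation_factor=1)
    ipu_strategy.set_precision_config(enable_fp16=False)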
@@ -3786,86 +3786,142 @@ All parameter, weight, gradient are variables in Paddle.
 #ifdef PADDLE_WITH_IPU
   py::class_<platform::ipu::IpuBackend,
-             std::shared_ptr<platform::ipu::IpuBackend>>(m, "IpuBackend")
-      .def(py::init(&platform::ipu::IpuBackend::GetNewInstance))
-      .def("clear", &platform::ipu::IpuBackend::Clear)
-      .def("set_scope", &platform::ipu::IpuBackend::SetScope)
-      .def("set_ipu_strategy", &platform::ipu::IpuBackend::SetIpuStrategy);
-
-  py::class_<platform::ipu::IpuStrategy> ipu_strategy(m, "IpuStrategy");
-  ipu_strategy.def(py::init())
-      .def_property(
-          "num_ipus",
-          [](const platform::ipu::IpuStrategy &self) { return self.num_ipus; },
-          [](platform::ipu::IpuStrategy &self, int num_ipus) {
-            self.num_ipus = num_ipus;
-          })
-      .def_property(
-          "accumulationFactor",
-          [](const platform::ipu::IpuStrategy &self) {
-            return self.popart_options_.accumulationFactor;
-          },
-          [](platform::ipu::IpuStrategy &self, int accumulationFactor) {
-            self.popart_options_.accumulationFactor = accumulationFactor;
-          })
-      .def_property("batches_per_step",
-                    [](const platform::ipu::IpuStrategy &self) {
-                      return self.batches_per_step;
-                    },
-                    [](platform::ipu::IpuStrategy &self, int batches_per_step) {
-                      self.batches_per_step = batches_per_step;
-                    })
-      .def_property("is_training",
-                    [](const platform::ipu::IpuStrategy &self) {
-                      return self.is_training;
-                    },
-                    [](platform::ipu::IpuStrategy &self, bool is_training) {
-                      self.is_training = is_training;
-                    })
-      .def_property(
-          "enable_pipelining",
-          [](const platform::ipu::IpuStrategy &self) {
-            return self.popart_options_.enablePipelining;
-          },
-          [](platform::ipu::IpuStrategy &self, bool enable_pipelining) {
-            self.popart_options_.enablePipelining = enable_pipelining;
-          })
-      .def_property(
-          "enable_manual_shard",
-          [](const platform::ipu::IpuStrategy &self) {
-            return self.popart_options_.virtualGraphMode ==
-                   platform::ipu::VirtualGraphMode::Manual;
-          },
-          [](platform::ipu::IpuStrategy &self, bool enable_ipu_shard) {
-            if (enable_ipu_shard) {
-              self.popart_options_.virtualGraphMode =
-                  platform::ipu::VirtualGraphMode::Manual;
-            } else {
-              self.popart_options_.virtualGraphMode =
-                  platform::ipu::VirtualGraphMode::Off;
-            }
-          })
-      .def_property("need_avg_shard",
-                    [](const platform::ipu::IpuStrategy &self) {
-                      return self.need_avg_shard;
-                    },
-                    [](platform::ipu::IpuStrategy &self, bool need_avg_shard) {
-                      self.need_avg_shard = need_avg_shard;
-                    })
-      .def_property("batch_size",
-                    [](const platform::ipu::IpuStrategy &self) {
-                      return self.batch_size;
-                    },
-                    [](platform::ipu::IpuStrategy &self, int batch_size) {
-                      self.batch_size = batch_size;
-                    })
-      .def_property("enable_fp16",
-                    [](const platform::ipu::IpuStrategy &self) {
-                      return self.enable_fp16;
-                    },
-                    [](platform::ipu::IpuStrategy &self, bool enable_fp16) {
-                      self.enable_fp16 = enable_fp16;
-                    });
+             std::unique_ptr<platform::ipu::IpuBackend, py::nodelete>>(
+      m, "IpuBackend")
+      // manage IpuBackend in C++
+      .def("get_instance",
+           []() {
+             return std::unique_ptr<platform::ipu::IpuBackend, py::nodelete>(
+                 platform::ipu::IpuBackend::GetInstance());
+           },
+           py::return_value_policy::reference)
+      .def("detach", &platform::ipu::IpuBackend::Detach)
+      .def("reset", &platform::ipu::IpuBackend::Reset)
+      .def("set_scope", &platform::ipu::IpuBackend::SetScope)
+      .def("set_ipu_strategy", &platform::ipu::IpuBackend::SetIpuStrategy)
+      .def("save_model_proto", &platform::ipu::IpuBackend::SaveModelProto);
+
+  py::class_<platform::ipu::IpuStrategy>(m, "IpuStrategy")
+      .def(py::init())
+      .def("set_options",
+           [](platform::ipu::IpuStrategy &self, const py::dict &opt) {
+             for (auto element : opt) {
+               auto option_name = element.first.cast<std::string>();
+               VLOG(10) << "Set option: " << option_name;
+               if (py::isinstance<py::bool_>(element.second)) {
+                 self.AddBoolOption(option_name, element.second.cast<bool>());
+               } else if (py::isinstance<py::float_>(element.second)) {
+                 self.AddDoubleOption(option_name,
+                                      element.second.cast<double>());
+               } else if (py::isinstance<py::int_>(element.second)) {
+                 self.AddUint64Option(option_name,
+                                      element.second.cast<std::uint64_t>());
+               } else if (py::isinstance<py::str>(element.second)) {
+                 self.AddStringOption(option_name,
+                                      element.second.cast<std::string>());
+               } else if (py::isinstance<py::set>(element.second) ||
+                          py::isinstance<py::list>(element.second)) {
+                 for (auto option : element.second.cast<py::list>()) {
+                   std::string option_val;
+                   if (py::isinstance<py::str>(option)) {
+                     option_val = option.cast<std::string>();
+                   } else if (py::isinstance<py::int_>(option)) {
+                     option_val = std::to_string(option.cast<std::uint64_t>());
+                   } else {
+                     PADDLE_THROW(platform::errors::Unimplemented(
+                         "Failed to convert type: %s when set IpuStrategy "
+                         "option: %s",
+                         option.get_type(), option_name));
+                   }
+                   self.InsertStringOption(option_name, option_val);
+                 }
+               } else if (py::isinstance<py::dict>(element.second)) {
+                 if (option_name.rfind("location_", 0) == 0) {
+                   for (auto option : element.second.cast<py::dict>()) {
+                     self.SetTensorLocation(
+                         option_name, option.first.cast<std::string>(),
+                         option.second.cast<std::uint64_t>());
+                   }
+                 } else if (option_name == "custom_op") {
+                   std::string paddle_op;
+                   std::string popart_op;
+                   std::string domain;
+                   int version = -1;
+                   for (auto option : element.second.cast<py::dict>()) {
+                     std::string option_key = option.first.cast<std::string>();
+                     if (option_key == "paddle_op") {
+                       paddle_op = option.second.cast<std::string>();
+                     } else if (option_key == "popart_op") {
+                       popart_op = option.second.cast<std::string>();
+                     } else if (option_key == "domain") {
+                       domain = option.second.cast<std::string>();
+                     } else if (option_key == "version") {
+                       version = option.second.cast<int>();
+                     } else {
+                       PADDLE_THROW(platform::errors::InvalidArgument(
+                           "Invalid argument, key must be one of paddle_op, "
+                           "popart_op, domain or version, but received %s",
+                           option_key));
+                     }
+                   }
+                   self.AddCustomOp(paddle_op, popart_op, domain, version);
+                 } else {
+                   for (auto option : element.second.cast<py::dict>()) {
+                     std::string option_key = option.first.cast<std::string>();
+                     std::string option_val;
+                     if (py::isinstance<py::str>(option.second)) {
+                       option_val = option.second.cast<std::string>();
+                     } else if (py::isinstance<py::int_>(option.second)) {
+                       option_val =
+                           std::to_string(option.second.cast<std::uint64_t>());
+                     } else {
+                       PADDLE_THROW(platform::errors::Unimplemented(
+                           "Failed to convert value type: %s when set "
+                           "IpuStrategy option: %s",
+                           option.second.get_type(), option_key));
+                     }
+                     self.InsertStringPairOption(option_name, option_key,
+                                                 option_val);
+                   }
+                 }
+               } else {
+                 PADDLE_THROW(platform::errors::InvalidArgument(
+                     "Invalid IpuStrategy option value type: %s, please check "
+                     "input value for option: %s",
+                     element.second.get_type(), option_name));
+               }
+             }
+           })
+      .def("get_option",
+           [](platform::ipu::IpuStrategy &self, const std::string &name) {
+             py::dict res;
+             auto option_type = self.GetOptionType(name);
+             res["name"] = name;
+             res["type"] = option_type;
+             if (option_type == "vector") {
+               auto value = self.GetVectorOption(name);
+               res["value"] = value;
+             } else if (option_type == "map") {
+               auto value = self.GetMapOption(name);
+               res["value"] = value;
+             } else {
+               auto value_s = self.GetOption(name);
+               res["value_s"] = value_s;
+               if (option_type == "bool") {
+                 res["value"] = static_cast<bool>(std::stoi(value_s));
+               } else if (option_type == "uint64") {
+                 res["value"] = std::stoul(value_s);
+               } else if (option_type == "double") {
+                 res["value"] = std::stod(value_s);
+               } else if (option_type == "string") {
+                 res["value"] = value_s;
+               }
+             }
+             return res;
+           })
+      .def("enable_pattern", &platform::ipu::IpuStrategy::EnablePattern)
+      .def("disable_pattern", &platform::ipu::IpuStrategy::DisablePattern)
+      .def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled);
 #endif
   BindFleetWrapper(&m);
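The `set_options` binding above dispatches on the Python value type: bool, int, float and str map to typed single options; lists and sets insert repeated string options; a dict whose key starts with `location_` sets a tensor location; and the `custom_op` dict registers a Paddle-to-PopART op mapping. A minimal sketch exercising each branch (IPU build assumed; `loss_scaling` is a hypothetical float option name used only for illustration):

    # required: ipu
    import paddle
    import paddle.static as static

    paddle.enable_static()
    ipu_strategy = static.IpuStrategy()
    ipu_strategy.set_options({
        'is_training': True,        # bool  -> AddBoolOption
        'num_ipus': 1,              # int   -> AddUint64Option
        'loss_scaling': 1.0,        # float -> AddDoubleOption (hypothetical name)
        'location_optimizer': {     # 'location_' prefix -> SetTensorLocation
            'on_chip': 0,
            'use_replicated_tensor_sharding': 1,
        },
        'custom_op': {              # 'custom_op' dict -> AddCustomOp
            'paddle_op': 'paddle_relu',
            'popart_op': 'popart_relu',
            'domain': 'custom.ops',
            'version': 1,
        },
    })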
@@ -350,8 +350,14 @@ void SetTensorFromPyArrayT(
     auto type = framework::ToDataType(std::type_index(typeid(T)));
     self->ResetHolderWithType(holder, framework::TransToPtenDataType(type));
   } else {
+    // IPU does not store Tensor data, Tensor will be created on CPU
+    if (!self->initialized()) {
+      auto dst = self->mutable_data<T>(place);
+      std::memcpy(dst, array.data(), array.nbytes());
+    } else {
       auto dst = self->mutable_data<T>(self->place());
       std::memcpy(dst, array.data(), array.nbytes());
+    }
   }
#else
PADDLE_THROW(platform::errors::PermissionDenied(
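The new branch keeps feeding data through the same `tensor.set(...)` path used elsewhere in this diff; with an IPU build the buffer is staged on the CPU rather than on the device. A minimal sketch (the variable name `x` is illustrative):

    # required: ipu
    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()
    scope = static.global_scope()
    tensor = scope.var('x').get_tensor()  # 'x' is a hypothetical variable
    # For IPU, the memory behind this Tensor lives on the CPU.
    tensor.set(np.ones([2, 2], dtype=np.float32), paddle.CPUPlace())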
@@ -502,9 +502,6 @@ class IpuStrategy(object):
    """
    Help users precisely control the graph building in :code:`paddle.static.IpuCompiledProgram` .
-
-    Args:
-        None.

    Returns:
        The IpuStrategy instance.
@@ -517,23 +514,36 @@ class IpuStrategy(object):
            import paddle.static as static

            paddle.enable_static()
            ipu_strategy = static.IpuStrategy()
    """

    def __init__(self):
        if core.is_compiled_with_ipu():
            self._ipu_strategy = core.IpuStrategy()
+            default_options = {
+                'location_optimizer': {
+                    'on_chip': 0,
+                    'use_replicated_tensor_sharding': 1,
+                },  # set optimizer location
+                'accumulation_and_replication_reduction_type':
+                1,  # popart::ReductionType::Mean
+                'mean_accumulation_and_replication_reduction_strategy':
+                1,  # popart::MeanReductionStrategy::Post
+            }
+            self._ipu_strategy.set_options(default_options)
+            self.has_custom_ops = False
+            self.custom_op_names = []
        else:
            raise RuntimeError(
                "Can not use IpuStrategy in non IPU compiled environment, please re-compile with WITH_IPU=ON."
            )

-    def SetGraphConfig(self,
-                       num_ipus=1,
-                       is_training=True,
-                       batch_size=1,
-                       enable_manual_shard=False,
-                       need_avg_shard=False):
+    def set_graph_config(self,
+                         num_ipus=1,
+                         is_training=True,
+                         batch_size=1,
+                         enable_manual_shard=False):
@@ -544,8 +554,6 @@ class IpuStrategy(object):
            if the batch-size in the graph is dynamic. Default 1, which means the batch-size would be set to 1 if the batch-size is dynamic.
        enable_manual_shard (bool, optional): Enable graph sharding or not. Only if num_ipus > 1, enable_manual_shard is able to be set True.
            Default False, which means disabled.
-        need_avg_shard (bool, optional): Enable auto graph sharding or not. Only if num_ipus > 1 and enable_manual_shard=True, need_avg_shard is able to be set True.
-            Default False, which means disabled.

    Returns:
        None.
@@ -559,32 +567,29 @@ class IpuStrategy(object):
            import paddle.static as static

            paddle.enable_static()
            ipu_strategy = static.IpuStrategy()
-            ipu_strategy.SetGraphConfig(num_ipus=1,
-                                        is_training=True,
-                                        batch_size=1,
-                                        enable_manual_shard=False,
-                                        need_avg_shard=False)
+            ipu_strategy.set_graph_config(num_ipus=1,
+                                          is_training=True,
+                                          batch_size=1,
+                                          enable_manual_shard=False)
        """
-        self._ipu_strategy.num_ipus = num_ipus
-        self._ipu_strategy.is_training = is_training
-        self._ipu_strategy.batch_size = batch_size
-        self._ipu_strategy.enable_manual_shard = enable_manual_shard
-        if self._ipu_strategy.num_ipus == 1 and self._ipu_strategy.enable_manual_shard:
+        if num_ipus == 1 and enable_manual_shard:
            raise RuntimeError(
                "Only if num_ipus > 1, enable_manual_shard is able to be set True."
            )
-        self._ipu_strategy.need_avg_shard = need_avg_shard
-        if self._ipu_strategy.enable_manual_shard != True and self._ipu_strategy.need_avg_shard:
-            raise RuntimeError(
-                "Only if enable_manual_shard=True, need_avg_shard is able to be set True."
-            )
+        options = {
+            'num_ipus': num_ipus,
+            'is_training': is_training,
+            'micro_batch_size': batch_size,
+            'enable_manual_shard': enable_manual_shard,
+        }
+        self.set_options(options)

-    def SetPipeliningConfig(self,
-                            enable_pipelining=False,
-                            batches_per_step=1,
-                            accumulationFactor=1):
+    def set_pipelining_config(self,
+                              enable_pipelining=False,
+                              batches_per_step=1,
+                              accumulation_factor=1):
        """
        Set pipelining configuration to the IpuStrategy instance. Used to optimize the throughput performance.
@@ -593,7 +598,7 @@ class IpuStrategy(object):
            Default False, which means disabled.
        batches_per_step (int, optional): Set the batches per run in data pipelining mode. Only if enable_pipelining=True, batches_per_step is able to be set > 1.
            Default 1, which means no data pipelining.
-        accumulationFactor (int, optional): Specify the number of micro-batches to accumulate
+        accumulation_factor (int, optional): Specify the number of micro-batches to accumulate
            before applying the varUpdate. Default 1, which means disable the accumulation.

    Returns:
@@ -610,23 +615,23 @@ class IpuStrategy(object):
            paddle.enable_static()

            ipu_strategy = static.IpuStrategy()
-            ipu_strategy.SetPipeliningConfig(enable_pipelining=False,
-                                             batches_per_step=1,
-                                             accumulationFactor=1)
+            ipu_strategy.set_pipelining_config(enable_pipelining=False,
+                                               batches_per_step=1,
+                                               accumulation_factor=1)
        """
-        self._ipu_strategy.enable_pipelining = enable_pipelining
-        if self._ipu_strategy.enable_manual_shard != True and self._ipu_strategy.enable_pipelining:
+        enable_manual_shard = self.get_option('enable_manual_shard')
+        if not enable_manual_shard and enable_pipelining:
            raise RuntimeError(
                "Only if enable_manual_shard=True, enable_pipelining is able to be set True."
            )
-        self._ipu_strategy.batches_per_step = batches_per_step
-        if self._ipu_strategy.enable_pipelining != True and self._ipu_strategy.batches_per_step > 1:
-            raise RuntimeError(
-                "Only if enable_pipelining=True, batches_per_step is able to be set > 1."
-            )
-        self._ipu_strategy.accumulationFactor = accumulationFactor
+        options = {
+            'enable_pipelining': enable_pipelining,
+            'batches_per_step': batches_per_step,
+            'accumulation_factor': accumulation_factor,
+        }
+        self.set_options(options)

-    def SetHalfConfig(self, enable_fp16=False):
+    def set_precision_config(self, enable_fp16=False):
        """
        Set half computation configuration to the IpuStrategy instance. Used to optimize the performance.
@@ -647,73 +652,135 @@ class IpuStrategy(object):
            paddle.enable_static()

            ipu_strategy = static.IpuStrategy()
-            ipu_strategy.SetHalfConfig(enable_fp16=False)
+            ipu_strategy.set_precision_config(enable_fp16=False)
        """
-        self._ipu_strategy.enable_fp16 = enable_fp16
-
-    @property
-    def num_ipus(self):
-        """
-        Get the number of IPU devices from IpuStrategy instance.
-        """
-        return self._ipu_strategy.num_ipus
-
-    @property
-    def is_training(self):
-        """
-        Get the boolean of training or inference from IpuStrategy instance.
-        """
-        return self._ipu_strategy.is_training
-
-    @property
-    def batch_size(self):
-        """
-        Get the batch_size used in dynamic batch_size graph from IpuStrategy instance.
-        """
-        return self._ipu_strategy.batch_size
-
-    @property
-    def enable_manual_shard(self):
-        """
-        Get the boolean of enable manual shard or not from IpuStrategy instance.
-        """
-        return self._ipu_strategy.enable_manual_shard
-
-    @property
-    def need_avg_shard(self):
-        """
-        Get the boolean of need average shard or not from IpuStrategy instance.
-        """
-        return self._ipu_strategy.need_avg_shard
-
-    @property
-    def enable_pipelining(self):
-        """
-        Get the boolean of enable pipelining or not from IpuStrategy instance.
-        """
-        return self._ipu_strategy.enable_pipelining
-
-    @property
-    def batches_per_step(self):
-        """
-        Get the number of batch_size per run in the pipelining mode from IpuStrategy instance.
-        """
-        return self._ipu_strategy.batches_per_step
-
-    @property
-    def accumulationFactor(self):
-        """
-        Get the number of micro-batches to accumulate before applying the varUpdate from IpuStrategy instance.
-        """
-        return self._ipu_strategy.accumulationFactor
+        options = {'enable_fp16': enable_fp16, }
+        self.set_options(options)
+
+    def add_custom_op(self,
+                      paddle_op,
+                      popart_op=None,
+                      domain='custom.ops',
+                      version=1):
+        """
+        Add a mapping to use popart custom ops running on the IPU.
+
+        Args:
+            paddle_op(str): the name of custom op in paddle.
+            popart_op(str): the name of custom op in popart.
+            domain(str): domain name of custom op in popart.
+            version(int): version of custom op in popart.
+
+        Returns:
+            None.
+
+        Examples:
+            .. code-block:: python
+
+                # required: ipu
+                import paddle
+                import paddle.static as static
+
+                paddle.enable_static()
+
+                ipu_strategy = static.IpuStrategy()
+                ipu_strategy.add_custom_op('paddle_relu', 'popart_relu')
+        """
+        if popart_op is None:
+            popart_op = paddle_op
+        custom_op = {
+            'paddle_op': paddle_op,
+            'popart_op': popart_op,
+            'domain': domain,
+            'version': version,
+        }
+        self.set_options({'custom_op': custom_op})
+        self.custom_op_names.append(paddle_op)
+        if not self.has_custom_ops:
+            self.has_custom_ops = True
+
+    def set_options(self, options):
+        """
+        Set options from dict.
+
+        Args:
+            options(dict): dict of options.
+
+        Returns:
+            None.
+
+        Examples:
+            .. code-block:: python
+
+                # required: ipu
+                import paddle
+                import paddle.static as static
+
+                paddle.enable_static()
+
+                ipu_strategy = static.IpuStrategy()
+                options = {'num_ipus':1, 'enable_fp16': True}
+                ipu_strategy.set_options(options)
+        """
+        self._ipu_strategy.set_options(options)
+
+    def get_option(self, option):
+        """
+        Get option.
+
+        Args:
+            option(str): name of option.
+
+        Returns:
+            option value.
+
+        Examples:
+            .. code-block:: python
+
+                # required: ipu
+                import paddle
+                import paddle.static as static
+
+                paddle.enable_static()
+
+                ipu_strategy = static.IpuStrategy()
+                num_ipus = ipu_strategy.get_option('num_ipus')
+        """
+        return self._ipu_strategy.get_option(option)['value']
+
+    @property
+    def num_ipus(self):
+        """
+        Get the number of IPU devices from IpuStrategy instance.
+        """
+        return self.get_option('num_ipus')
+
+    @property
+    def is_training(self):
+        """
+        Get the boolean of training or inference from IpuStrategy instance.
+        """
+        return self.get_option('is_training')
+
+    @property
+    def enable_pipelining(self):
+        """
+        Get the boolean of enable pipelining or not from IpuStrategy instance.
+        """
+        return self.get_option('enable_pipelining')

    @property
    def enable_fp16(self):
        """
        Get the boolean of float16 mode or not from IpuStrategy instance.
        """
-        return self._ipu_strategy.enable_fp16
+        return self.get_option('enable_fp16')
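The remaining read-only properties are now thin wrappers over `get_option`, so they reflect whatever was last set through `set_options` or the `set_*_config` helpers. A minimal sketch (IPU build assumed):

    # required: ipu
    import paddle
    import paddle.static as static

    paddle.enable_static()
    ipu_strategy = static.IpuStrategy()
    ipu_strategy.set_graph_config(num_ipus=1, is_training=True)
    assert ipu_strategy.num_ipus == 1   # same as ipu_strategy.get_option('num_ipus')
    assert ipu_strategy.is_training     # same as ipu_strategy.get_option('is_training')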
class IpuCompiledProgram(object):
@@ -750,9 +817,9 @@ class IpuCompiledProgram(object):
            main_prog = static.default_main_program()

            ipu_strategy = static.IpuStrategy()
-            ipu_strategy.SetGraphConfig(num_ipus=1, is_training=True, batch_size=1)
-            ipu_strategy.SetPipeliningConfig(enable_pipelining=False, batches_per_step=1, accumulationFactor=1)
-            ipu_strategy.SetHalfConfig(enable_fp16=False)
+            ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1)
+            ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, accumulation_factor=1)
+            ipu_strategy.set_precision_config(enable_fp16=False)

            ipu_compiled_program = static.IpuCompiledProgram(
                main_prog,
@@ -766,14 +833,12 @@ class IpuCompiledProgram(object):
        )

        if program is None:
-            program = default_main_program()
+            program = framework.default_main_program()

        if not isinstance(program, framework.Program):
            raise TypeError(
                "The type of program is wrong, expected Program, but got %s" %
                type(program))
-        # import here to avoiding confused
-        import paddle

        self._program = program
        self._compiled = False
@@ -781,23 +846,21 @@ class IpuCompiledProgram(object):
        if scope is not None:
            self._scope = scope
        else:
-            # import here to avoiding confused
-            import paddle
            self._scope = paddle.static.global_scope()

        if ipu_strategy is not None:
-            self._ipu_strategy = ipu_strategy._ipu_strategy
+            self._ipu_strategy = ipu_strategy
        else:
-            self._ipu_strategy = core.IpuStrategy()
-
-        self._backend = core.IpuBackend()
-        self._backend.set_scope(self._scope)
-        self._backend.set_ipu_strategy(self._ipu_strategy)
-        self._graph_passes = [
-            "optimizer_extract_pass", "optimizer_state_align_pass",
-            "forward_graph_extract_pass", "infer_shape_pass", "avg_shard_pass",
-            "popart_canonicalization_pass"
-        ]
-        global ipu_compiler_ref
-        ipu_compiler_ref = self
+            self._ipu_strategy = IpuStrategy()
+
+        if ipu_strategy.has_custom_ops:
+            self._custom_op_names = set(ipu_strategy.custom_op_names)
+        else:
+            self._custom_op_names = ()
+
+        self._backend = core.IpuBackend.get_instance()
@@ -828,20 +891,23 @@ class IpuCompiledProgram(object):
            main_prog = static.default_main_program()

            ipu_strategy = static.IpuStrategy()
-            ipu_strategy.SetGraphConfig(num_ipus=1, is_training=True, batch_size=1)
-            ipu_strategy.SetPipeliningConfig(enable_pipelining=False, batches_per_step=1, accumulationFactor=1)
-            ipu_strategy.SetHalfConfig(enable_fp16=False)
+            ipu_strategy.set_graph_config(num_ipus=1, is_training=True, batch_size=1)
+            ipu_strategy.set_pipelining_config(enable_pipelining=False, batches_per_step=1, accumulation_factor=1)
+            ipu_strategy.set_precision_config(enable_fp16=False)

            program = static.IpuCompiledProgram(
                main_prog,
                ipu_strategy=ipu_strategy).compile([a.name], [b.name])
        """
+        self._backend.set_scope(self._scope)
+        self._backend.set_ipu_strategy(self._ipu_strategy._ipu_strategy)
+
        # feed and fetch don't have corresponding popart ops, so we remove both here
        global_block = self._program.global_block()
        need_to_remove_op_index = []

        for i, op in enumerate(global_block.ops):
            op.desc.set_is_target(False)
-            if op.type == "feed" or op.type == "fetch":
+            if op.type == 'feed' or op.type == 'fetch':
                need_to_remove_op_index.append(i)

        for index in need_to_remove_op_index[::-1]:
@@ -854,26 +920,45 @@ class IpuCompiledProgram(object):
        self._program.desc.flush()
        self._graph = core.Graph(self._program.desc)

-        for pass_name in self._graph_passes:
-            graph_pass = core.get_pass(pass_name)
-            if pass_name == "infer_shape_pass":
-                graph_pass.set("feed_list", feed_list)
-            graph_pass.apply(self._graph)
-
-        ipu_inplace_pass = core.get_pass("ipu_inplace_pass")
-        ipu_inplace_pass.set("feed_list", feed_list)
-        ipu_inplace_pass.set("fetch_list", fetch_list)
-        ipu_inplace_pass.apply(self._graph)
-
-        ipu_graph_builder_pass = core.get_pass("ipu_graph_builder_pass")
-        ipu_graph_builder_pass.set("feed_list", feed_list)
-        ipu_graph_builder_pass.set("fetch_list", fetch_list)
-        ipu_graph_builder_pass.apply(self._graph)
-
-        ipu_runtime_replacer_pass = core.get_pass("ipu_runtime_replacer_pass")
-        ipu_runtime_replacer_pass.set("feed_list", feed_list)
-        ipu_runtime_replacer_pass.set("fetch_list", fetch_list)
-        ipu_runtime_replacer_pass.apply(self._graph)
+        if self._ipu_strategy.is_training:
+            passes = [
+                'optimizer_extract_pass',
+                'optimizer_state_align_pass',
+            ]
+            for pass_name in passes:
+                a_pass = core.get_pass(pass_name)
+                a_pass.apply(self._graph)
+
+        passes = [
+            'forward_graph_extract_pass',
+            'infer_shape_pass',
+            'avg_shard_pass',
+            'delete_scale_op_pass',
+        ]
+        for pass_name in passes:
+            a_pass = core.get_pass(pass_name)
+            if pass_name == 'infer_shape_pass':
+                a_pass.set('feed_list', feed_list)
+            a_pass.apply(self._graph)
+
+        a_pass = core.get_pass('popart_canonicalization_pass')
+        if self._custom_op_names:
+            a_pass.set('custom_ops', self._custom_op_names)
+        a_pass.apply(self._graph)
+
+        a_pass = core.get_pass("transfer_cast_op_pass")
+        a_pass.apply(self._graph)
+
+        passes = [
+            'ipu_inplace_pass',
+            'ipu_graph_builder_pass',
+            'ipu_runtime_replacer_pass',
+        ]
+        for pass_name in passes:
+            a_pass = core.get_pass(pass_name)
+            a_pass.set('feed_list', feed_list)
+            a_pass.set('fetch_list', fetch_list)
+            a_pass.apply(self._graph)

        convert_pass = core.get_pass('graph_to_program_pass')
        desc = core.ProgramDesc()
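End to end, a custom op registered on the strategy flows through `compile()` into `popart_canonicalization_pass` via the `custom_ops` attribute. A minimal sketch (the op mapping and the feed/fetch names are illustrative):

    # required: ipu
    import paddle
    import paddle.static as static

    paddle.enable_static()
    main_prog = static.default_main_program()

    ipu_strategy = static.IpuStrategy()
    ipu_strategy.add_custom_op('paddle_relu', 'popart_relu')  # illustrative mapping

    # compile() hands ipu_strategy.custom_op_names to popart_canonicalization_pass.
    program = static.IpuCompiledProgram(
        main_prog, ipu_strategy=ipu_strategy).compile(['x'], ['out'])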
@@ -904,9 +989,3 @@ class IpuCompiledProgram(object):
        program.org_program = self._program

        return program
-
-    def clean(self):
-        self._backend.clear()
-
-    def __del__(self):
-        self.clean()
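With `clean`/`__del__` removed, the backend's lifetime is managed by the C++ singleton bound earlier in this diff; callers that need to release the device explicitly can go through the new binding. A minimal sketch, assuming the same `core` import that compiler.py uses:

    # required: ipu
    from paddle.fluid import core

    backend = core.IpuBackend.get_instance()
    backend.detach()  # detach from the IPU device
    backend.reset()   # clear backend state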
@@ -1583,9 +1583,6 @@ class Executor(object):
                lr_sheduler = program.lr_sheduler
                lr_value = lr_sheduler()
                lr_var = program.global_block().vars[lr_sheduler._var_name]
-                if core.is_compiled_with_ipu():
-                    if hasattr(program.lr_sheduler, 'lr_var'):
-                        lr_var = program.lr_sheduler.lr_var
                data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
                tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
                tensor.set(data, self.place)