diff --git a/doc/design/optimizer.md b/doc/design/optimizer.md
index 202b4b65103c0b7c536a9cb466c4120ce134d8c3..691081c268b848811bf5ee6d6a41edfe0f47eec0 100644
--- a/doc/design/optimizer.md
+++ b/doc/design/optimizer.md
@@ -79,7 +79,7 @@ class Optimizer(object):
     def minimize(self, loss, parameter_list):
         """Add operations to minimize `loss` by updating `parameter_list`.
 
-        This method combines interface `append_backward_ops()` and
+        This method combines interface `append_backward()` and
         `create_optimization_pass()` into one.
         """
         params_grads = self.create_backward_pass(loss, parameter_list)
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index b361e64438251c1df827667fb825e7f5909fb09e..781bbb4c19f1c610df485c3061ca8b510e727019 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -88,6 +88,14 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
   need_update_ = true;
 }
 
+void OpDesc::CopyFrom(const OpDesc &op_desc) {
+  desc_.set_type(op_desc.Type());
+  inputs_ = op_desc.inputs_;
+  outputs_ = op_desc.outputs_;
+  attrs_ = op_desc.attrs_;
+  need_update_ = true;
+}
+
 OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
     : desc_(desc), need_update_(false) {
   // restore inputs_
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 93d4a88f3c390551ab41e42ec2f6f30f52e306db..4cf784a0d0d319d09caa27b4e2b589bd7ac4f324 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -35,6 +35,8 @@ class OpDesc {
 
   OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
 
+  void CopyFrom(const OpDesc &op_desc);
+
   proto::OpDesc *Proto();
 
   std::string Type() const { return desc_.type(); }
diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc
index bd8973eeb369aabd2c52d4fccf799657c564ee78..7d002b9ea0b597730685ee03b021c4982f787f49 100644
--- a/paddle/framework/var_desc.cc
+++ b/paddle/framework/var_desc.cc
@@ -74,7 +74,7 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().tensor();
     default:
-      PADDLE_THROW("Unexpected branch.");
+      PADDLE_THROW("The type of var '", this->Name(), "' is unsupported.");
   }
 }
 
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index 2b35e4532a9c9f72f473020d472244234af24248..d4f12f0a106e077ac31aa37f46857b74e1e99b59 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -302,8 +302,29 @@ void set_constant(const platform::DeviceContext& context,
 #endif
 }
 
+template <typename T>
+struct RowwiseAdd<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& vector, framework::Tensor* output) {
+    auto in_dims = input.dims();
+    auto size = input.numel() / in_dims[0];
+    PADDLE_ENFORCE_EQ(vector.numel(), size);
+    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
+
+    auto in = framework::EigenMatrix<T>::From(input);
+    auto vec = framework::EigenVector<T>::Flatten(vector);
+    auto out = framework::EigenMatrix<T>::From(*output);
+
+    for (int64_t i = 0; i < in_dims[0]; ++i) {
+      out.chip(i, 0) = in.chip(i, 0) + vec;
+    }
+  }
+};
+
 template struct RowwiseAdd<platform::CPUDeviceContext, float>;
 template struct RowwiseAdd<platform::CPUDeviceContext, double>;
+
 template struct ColwiseSum<platform::CPUDeviceContext, float>;
 template struct ColwiseSum<platform::CPUDeviceContext, double>;
 
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 927838a0948d2df5701b8e9189f59cdd66396b52..d47a7f818ded61baf31e46ea3b8ae3101324111f 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -273,6 +273,35 @@ void set_constant_with_place(
                                  TensorSetConstantGPU(context, tensor, value));
 }
 
+template <typename T>
+__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
+                                 int num) {
+  T tmp = 1.0 / width;
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    int h = i * tmp;
+    int w = i - h * width;
+    c[i] = a[i] + b[w];
+  }
+}
+
+template <typename T>
+struct RowwiseAdd<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& vector, framework::Tensor* output) {
+    auto in_dims = input.dims();
+    auto size = input.numel() / in_dims[0];
+    PADDLE_ENFORCE_EQ(vector.numel(), size);
+    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
+    int blocks = 512;
+    int grids = (input.numel() + blocks - 1) / blocks;
+    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
+        input.data<T>(), vector.data<T>(), output->data<T>(),
+        static_cast<int>(in_dims[1]), static_cast<int>(input.numel()));
+  }
+};
+
 template struct RowwiseAdd<platform::CUDADeviceContext, float>;
 template struct RowwiseAdd<platform::CUDADeviceContext, double>;
 template struct ColwiseSum<platform::CUDADeviceContext, float>;
diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h
index ddd798dace17012b7d9a949567a90d48067e6b15..de591626df28e2bc3391b609f909612411398247 100644
--- a/paddle/operators/math/math_function_impl.h
+++ b/paddle/operators/math/math_function_impl.h
@@ -45,25 +45,6 @@ void Transpose<DeviceContext, T, Rank>::operator()(
   eigen_out.device(*dev) = eigen_in.shuffle(permute);
 }
 
-template <typename DeviceContext, typename T>
-void RowwiseAdd<DeviceContext, T>::operator()(const DeviceContext& context,
-                                              const framework::Tensor& input,
-                                              const framework::Tensor& vector,
-                                              framework::Tensor* output) {
-  auto in_dims = input.dims();
-  auto size = input.numel() / in_dims[0];
-  PADDLE_ENFORCE_EQ(vector.numel(), size);
-  PADDLE_ENFORCE_EQ(output->dims(), in_dims);
-
-  auto in = framework::EigenMatrix<T>::From(input);
-  auto vec = framework::EigenMatrix<T>::From(vector);
-  auto out = framework::EigenMatrix<T>::From(*output);
-  Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
-  Eigen::array<int, 2> bcast({{static_cast<int>(in_dims[0]), 1}});
-  out.device(*context.eigen_device()) =
-      in + vec.reshape(shape).broadcast(bcast);
-}
-
 template <typename DeviceContext, typename T>
 void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
                                               const framework::Tensor& input,
diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc
index f105370f226e2cceaac685f280d55134d4291028..07292d47e9c165c67fe4a30ee7d851c350beb2e0 100644
--- a/paddle/pybind/protobuf.cc
+++ b/paddle/pybind/protobuf.cc
@@ -171,12 +171,23 @@ void BindBlockDesc(py::module &m) {
              std::string name = byte_name;
              return self.HasVar(name);
            })
+      .def("has_var_recursive",
+           [](BlockDesc &self, py::bytes byte_name) {
+             std::string name = byte_name;
+             return self.HasVarRecursive(name);
+           })
       .def("find_var",
           [](BlockDesc &self, py::bytes byte_name) {
             std::string name = byte_name;
            return self.FindVar(name);
          },
          py::return_value_policy::reference)
+      .def("find_var_recursive",
+           [](BlockDesc &self, py::bytes byte_name) {
+             std::string name = byte_name;
+             return self.FindVarRecursive(name);
+           },
+           py::return_value_policy::reference)
       .def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference)
       .def("op_size", &BlockDesc::OpSize)
       .def("op", &BlockDesc::Op, py::return_value_policy::reference)
@@ -204,7 +215,7 @@ void BindVarDsec(py::module &m) {
       .def("set_shape", &VarDesc::SetShape)
       .def("set_dtype", &VarDesc::SetDataType)
       .def("shape", &VarDesc::Shape, py::return_value_policy::reference)
-      .def("dtype", &VarDesc::GetDataType)
+      .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
       .def("lod_level", &VarDesc::GetLodLevel)
       .def("set_lod_level", &VarDesc::SetLoDLevel)
       .def("type", &VarDesc::GetType)
@@ -236,14 +247,22 @@ void BindOpDesc(py::module &m) {
       .value("BLOCK", proto::AttrType::BLOCK);
 
   py::class_<OpDesc> op_desc(m, "OpDesc", "");
-  op_desc.def("type", &OpDesc::Type)
+  op_desc
+      .def("__init__", [](OpDesc &self) { new (&self) OpDesc(); },
+           py::return_value_policy::reference)
+      .def("copy_from", &OpDesc::CopyFrom)
+      .def("type", &OpDesc::Type)
       .def("set_type", &OpDesc::SetType)
       .def("input", &OpDesc::Input)
       .def("input_names", &OpDesc::InputNames)
-      .def("set_input", &OpDesc::SetInput)
       .def("output", &OpDesc::Output)
       .def("output_names", &OpDesc::OutputNames)
+      .def("set_input", &OpDesc::SetInput)
       .def("set_output", &OpDesc::SetOutput)
+      .def("input_arg_names", &OpDesc::InputArgumentNames)
+      .def("output_arg_names", &OpDesc::OutputArgumentNames)
+      .def("rename_input", &OpDesc::RenameInput)
+      .def("rename_output", &OpDesc::RenameOutput)
       .def("has_attr", &OpDesc::HasAttr)
       .def("attr_type", &OpDesc::GetAttrType)
       .def("attr_names", &OpDesc::AttrNames)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 07e38476e68b79f5b3192c619c89cd0e061cc686..04485ce7c1ab87f8655b0e6cbaecc36b3382f647 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -269,23 +269,22 @@ All parameter, weight, gradient are variables in Paddle.
         }
         return ret_values;
       });
-  m.def("get_grad_op_descs",
-        [](const OpDesc &op_desc,
-           const std::unordered_set<std::string> &no_grad_set,
-           std::unordered_map<std::string, std::string> &grad_to_var,
-           const std::vector<BlockDesc *> &grad_sub_block) {
-          std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
-              framework::OpInfoMap::Instance()
-                  .Get(op_desc.Type())
-                  .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
-                                 grad_sub_block);
-          std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
-          std::transform(
-              grad_op_descs.begin(), grad_op_descs.end(),
-              grad_op_desc_ptrs.begin(),
-              [](std::unique_ptr<OpDesc> &p) { return p.release(); });
-          return grad_op_desc_ptrs;
-        });
+  m.def(
+      "get_grad_op_desc", [](const OpDesc &op_desc,
+                             const std::unordered_set<std::string> &no_grad_set,
+                             const std::vector<BlockDesc *> &grad_sub_block) {
+        std::unordered_map<std::string, std::string> grad_to_var;
+        std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
+            framework::OpInfoMap::Instance()
+                .Get(op_desc.Type())
+                .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
+                               grad_sub_block);
+        std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
+        std::transform(grad_op_descs.begin(), grad_op_descs.end(),
+                       grad_op_desc_ptrs.begin(),
+                       [](std::unique_ptr<OpDesc> &p) { return p.release(); });
+        return std::make_pair(grad_op_desc_ptrs, grad_to_var);
+      });
   m.def("prune", [](const ProgramDesc &origin,
                     const std::vector<std::array<size_t, 2>> &targets) {
     ProgramDesc prog_with_targets(origin);
@@ -301,6 +300,8 @@ All parameter, weight, gradient are variables in Paddle.
     InferenceOptimize(*(origin.Proto()), &pruned_desc);
     return new ProgramDesc(pruned_desc);
   });
+  m.def("empty_var_name", []() { return framework::kEmptyVarName; });
+  m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });
   m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index f188582178f667125ec95cd230100fdb10ce7e88..6966cc75804b6b5a49ceb45a26994c23d2936bdb 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -1,17 +1,209 @@
 from paddle.v2.fluid import framework as framework
+from . import core
+import collections
 
-__all__ = ['append_backward_ops']
+__all__ = ['append_backward']
 
 
-def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
+def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
+                 end_idx=None):
+    if begin_idx is None:
+        begin_idx = 0
+    if end_idx is None:
+        end_idx = len(op_desc_list)
+    for i in range(begin_idx, end_idx):
+        op_desc = op_desc_list[i]
+        if isinstance(op_desc, tuple):
+            op_desc = op_desc[0]
+        op_desc.rename_input(old_name, new_name)
+        op_desc.rename_output(old_name, new_name)
+
+
+def _create_op_desc_(op_type, inputs, outputs, attrs):
+    op_desc = core.OpDesc()
+    op_desc.set_type(op_type)
+    for para, args in inputs.iteritems():
+        op_desc.set_input(para, args)
+    for para, args in outputs.iteritems():
+        op_desc.set_output(para, args)
+    for name, val in attrs.iteritems():
+        if isinstance(val, framework.Block):
+            op_desc.set_block_attr(name, val.desc)
+        else:
+            op_desc.set_attr(name, val)
+    return op_desc
+
+
+def _infer_var_data_type_(var_name, block):
+    grad_var = block.desc.find_var(var_name.encode("ascii"))
+    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
+    if block.desc.has_var_recursive(fwd_name):
+        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
+        grad_var.set_dtype(fwd_var.dtype())
+    else:
+        grad_var.set_dtype(core.DataType.FP32)
+
+
+def _all_in_set_(cands, s):
+    for c in cands:
+        if not c in s:
+            return False
+    return True
+
+
+def _strip_grad_suffix_(name):
+    pos = name.find(core.grad_var_suffix())
+    return name[:pos] if pos != -1 else name
+
+
+def _append_grad_suffix_(name):
+    return name + core.grad_var_suffix()
+
+
+def _addup_repetitive_outputs_(op_descs):
+    # In the backward pass, a variable may be the output of more than one op.
+    # In this case, the variable should be the accumulation of all the
+    # outputs. We adopt adding `sum_op`s to implement the accumulation.
+    pending_sum_ops = []
+    var_rename_count = collections.defaultdict(int)
+    renamed_vars = collections.defaultdict(list)
+    for idx, op_desc in enumerate(op_descs):
+        for var_name in op_desc.input_arg_names():
+            if len(renamed_vars[var_name]) > 1:
+                pending_sum_ops.append(
+                    (_create_op_desc_("sum", {"X": renamed_vars[var_name]},
+                                      {"Out": [var_name]}, {}), idx))
+                renamed_vars[var_name] = [var_name]
+        for var_name in op_desc.output_arg_names():
+            if var_name == core.empty_var_name(
+            ) or var_name in op_desc.input_arg_names():
+                # empty variable or inplace op
+                continue
+            if len(renamed_vars[var_name]) == 0:
+                # it's the first time we get the variable
+                renamed_vars[var_name] = [var_name]
+            else:
+                if len(renamed_vars[var_name]) == 1:
+                    new_name = var_name + "@RENAME@" + \
+                        str(var_rename_count[var_name])
+                    var_rename_count[var_name] += 1
+                    # rename original var_name
+                    renamed_vars[var_name][0] = new_name
+                    _rename_arg_(op_descs, var_name, new_name, 0, idx)
+                    _rename_arg_(pending_sum_ops, var_name, new_name)
+
+                new_name = var_name + "@RENAME@" + \
+                    str(var_rename_count[var_name])
+                var_rename_count[var_name] += 1
+                op_desc.rename_output(var_name, new_name)
+                renamed_vars[var_name].append(new_name)
+    for var_name, inputs in renamed_vars.iteritems():
+        if len(inputs) > 1:
+            pending_sum_ops.append((_create_op_desc_(
+                "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
+    # sum_op descs are sorted according to their insert position
+    for p in reversed(pending_sum_ops):
+        op_descs.insert(p[1], p[0])
+
+    return op_descs
+
+
+def _remove_no_grad_branch_(op_descs, no_grad_set):
+    # Remove ops whose outputs are all in no_grad_set
+    op_descs = filter(
+        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
+        op_descs)
+    # Insert fill_zeros_like_op
+    to_insert = []
+    for idx, op_desc in enumerate(op_descs):
+        for arg in op_desc.input_arg_names():
+            if core.grad_var_suffix() in arg and arg in no_grad_set:
+                to_insert.append((_create_op_desc_("fill_zeros_like", {
+                    "X": [_strip_grad_suffix_(arg)]
+                }, {"Y": [arg]}, {}), idx))
+
+    map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
+
+    return op_descs
+
+
+def _append_backward_ops_(target,
+                          block,
+                          target_block,
+                          no_grad_dict,
+                          grad_to_var,
+                          callback=None):
+    grad_op_descs = []
+    program = block.program
+    for op in reversed(block.ops):
+        grad_sub_block_list = []
+        # If the op has its own sub-block, deal with the sub-block first
+        if op.has_attr("sub_block"):
+            sub_block = program.block(op.block_attr("sub_block"))
+            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
+            _append_backward_ops_(target, sub_block, grad_sub_block,
+                                  no_grad_dict, grad_to_var, callback)
+            grad_sub_block_list.append(grad_sub_block.desc)
+
+        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
+            op.desc, no_grad_dict[block.idx], grad_sub_block_list)
+        grad_op_descs.extend(grad_op_desc)
+        grad_to_var.update(op_grad_to_var)
+
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+
+    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                            no_grad_dict[block.idx])
+
+    if target_block.idx == 0:
+        grad_op_descs.insert(
+            0,
+            _create_op_desc_("fill_constant", {}, {
+                "Out": [_append_grad_suffix_(target.name)]
+            }, {"shape": [1],
+                "value": 1.0,
+                "dtype": target.dtype}))
+    # append op_desc in grad_op_descs to target_block
+    for op_desc in grad_op_descs:
+        new_op_desc = target_block.desc.append_op()
+        new_op_desc.copy_from(op_desc)
+
+
+def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
+    for op_idx in range(start_op_idx, block.desc.op_size()):
+        op_desc = block.desc.op(op_idx)
+        if op_desc.has_attr("sub_block"):
+            sub_block = block.program.block(op_desc.block_attr("sub_block"))
+            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
+        new_vars = set()
+        # create new gradient variables
+        for grad_var_name in op_desc.output_arg_names():
+            grad_var_name = grad_var_name.encode("ascii")
+            if block.desc.has_var_recursive(
+                    grad_var_name) or grad_var_name == core.empty_var_name():
+                continue
+            block.desc.var(grad_var_name)
+            new_vars.add(grad_var_name)
+            if not grad_to_var.has_key(grad_var_name):
+                continue
+            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
+        # infer_shape and infer_type
+        op_desc.infer_var_type(block.desc)
+        op_desc.infer_shape(block.desc)
+        for arg in op_desc.output_arg_names():
+            if arg in new_vars:
+                _infer_var_data_type_(arg, block)
+
+
+def append_backward(loss, parameter_list=None, no_grad_set=None):
     """
     Create and add gradient Operators in BlockDesc to compute
     gradients of `loss` for parameters in parameter_list
 
     :param loss: an variable generated by cost function.
     :type loss: Variable
-    :param no_grad_set: variable that should not create gradient
-    :type no_grad_set: set
+    :param no_grad_set: variables that should not have their gradients computed
+    :type no_grad_set: set
     :param parameter_list: parameters that need to compute gradient and
     update to optimize the lost.
     :type: list
@@ -20,35 +212,53 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
     """
     assert isinstance(loss, framework.Variable)
 
+    program = loss.block.program
+    no_grad_dict = dict()
     if no_grad_set is None:
-        program = loss.block.program
         assert isinstance(program, framework.Program)
-        no_grad_set = list()
         for block in program.blocks:
             assert isinstance(block, framework.Block)
+            block_no_grad_set = set()
             for var in block.vars.itervalues():
                 assert isinstance(var, framework.Variable)
                 if var.stop_gradient:
-                    no_grad_set.append(var.name)
-        no_grad_set = set(no_grad_set)
+                    block_no_grad_set.add(_append_grad_suffix_(var.name))
+            no_grad_dict[block.idx] = block_no_grad_set
+    elif isinstance(no_grad_set, set):
+        no_grad_dict = {0: no_grad_set}
+    else:
+        raise ValueError("'no_grad_set' should be a set or None.")
+
+    grad_info_map = dict()
+    root_block = program.block(0)
+
+    fwd_op_num = root_block.desc.op_size()
+    current_block_idx = program.current_block_idx
+    grad_to_var = dict()
+
+    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
+                          grad_to_var)
+    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
+
+    program.current_block_idx = current_block_idx
+    program.sync_with_cpp()
 
-    param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
     if parameter_list is not None:
         parameters = parameter_list
     else:
-        params = loss.block.program.global_block().all_parameters()
+        params = program.global_block().all_parameters()
         parameters = [param.name for param in params]
     params_and_grads = []
     for param in parameters:
-        if param not in param_grad_map:
+        if param not in grad_info_map:
             raise ValueError("param %s is not in map" % param)
-        grad_info = param_grad_map[param]
-        grad_block = loss.block.program.block(grad_info[1])
+        grad_info = grad_info_map[param]
+        grad_block = grad_info[1]
         if not grad_block.has_var(grad_info[0]):
             raise ValueError("grad block[{0}] did not have grad var {1}".format(
                 grad_info[1], grad_info[0]))
         # Get the param var from the global block
-        param_var = loss.block.program.global_block().var(param)
+        param_var = program.global_block().var(param)
         grad_var = grad_block.var(grad_info[0])
         if loss.block.has_var(grad_info[0]):
             params_and_grads.append((param_var, grad_var))
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 111937f59c3ab05e5917a79ca7e1f81f59747fc3..49ece7b725e318d7526d58fe54c97cbe20200a7d 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -95,7 +95,9 @@ class DistributeTranspiler:
         """
         if program is None:
             program = default_main_program()
+        self.program = program
         self.trainers = trainers
+        self.optimize_ops = optimize_ops
         self._optimize_distributed(
             optimize_ops,
             program,
@@ -156,9 +158,10 @@ class DistributeTranspiler:
                 attrs={"endpoints": pserver_endpoints,
                        "epmap": epmap})
 
-    def get_trainer_program(optimize_ops, program):
+    def get_trainer_program(self):
         # remove optimize ops and add a send op to main_program
-        program.global_block().delete_ops(optimize_ops)
+        self.program.global_block().delete_ops(self.optimize_ops)
+        return self.program
 
     def _create_var_for_trainers(self, block, var, trainers):
         var_list = []
@@ -210,7 +213,6 @@ class DistributeTranspiler:
 
             if opt_op.inputs.has_key("Grad"):
                 if opt_op.inputs["Grad"].name in grad_var_names:
-                    print "appending ", opt_op.type, opt_op.inputs
                     optimize_sub_program.global_block().append_op(
                         type=opt_op.type,
                         inputs=opt_op.inputs,
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index add854306ea7fa527943de871d2716cd2aa9f530..b66a8bce5f4f15539007876c113afd3f878b00bc 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -663,7 +663,7 @@ class Block(object):
             end = list(self.ops).index(ops[-1])
         except Exception, e:
             raise e
-        self.desc.remove_op(start, end)
+        self.desc.remove_op(start, end + 1)
 
     def prepend_op(self, *args, **kwargs):
         op_desc = self.desc.prepend_op()
@@ -846,9 +846,11 @@ class Program(object):
         self.sync_with_cpp()
         return param_to_grad_info
 
-    def create_block(self):
+    def create_block(self, parent_idx=None):
         new_block_idx = len(self.blocks)
-        self.desc.append_block(self.current_block().desc)
+        parent = self.current_block() if parent_idx is None else self.block(
+            parent_idx)
+        self.desc.append_block(parent.desc)
         self.current_block_idx = new_block_idx
         self.blocks.append(Block(self, self.current_block_idx))
         return self.current_block()
diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py
index c56a531ed531cf0219e94854ba66c7399e003292..ff3e5315a2c2b115e4ba563f60de4139f248e93a 100644
--- a/python/paddle/v2/fluid/optimizer.py
+++ b/python/paddle/v2/fluid/optimizer.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 
 import framework
-from backward import append_backward_ops
+from backward import append_backward
 from framework import unique_name, program_guard
 from initializer import Constant
 from layer_helper import LayerHelper
@@ -194,10 +194,10 @@ class Optimizer(object):
                  no_grad_set=None):
         """Add operations to minimize `loss` by updating `parameter_list`.
 
-        This method combines interface `append_backward_ops()` and
+        This method combines interface `append_backward()` and
         `create_optimization_pass()` into one.
         """
-        params_grads = append_backward_ops(loss, parameter_list, no_grad_set)
+        params_grads = append_backward(loss, parameter_list, no_grad_set)
 
         params_grads = append_gradient_clip_ops(params_grads)
diff --git a/python/paddle/v2/fluid/tests/book/notest_recognize_digits_conv_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
similarity index 76%
rename from python/paddle/v2/fluid/tests/book/notest_recognize_digits_conv_dist.py
rename to python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
index 2680502efb91061be37a77fbe5b451960fdd15f7..20b4a8b34cd085ae51e6169f0d4eac58b7f3ffb2 100644
--- a/python/paddle/v2/fluid/tests/book/notest_recognize_digits_conv_dist.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
@@ -38,35 +38,43 @@ train_reader = paddle.batch(
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
+
 t = fluid.DistributeTranspiler()
+# all parameter server endpoints list for splitting parameters
 pserver_endpoints = os.getenv("PSERVERS")
+# server endpoint for current node
+current_endpoint = os.getenv("SERVER_ENDPOINT")
+# run as trainer or parameter server
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
 
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=1)
+t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
 
 if training_role == "PSERVER":
-    pserver_prog = t.get_pserver_program(pserver_endpoints, optimize_ops)
+    if not current_endpoint:
+        print("need env SERVER_ENDPOINT")
+        exit(1)
+    pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
     exe.run(fluid.default_startup_program())
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
+    trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
     exe.run(fluid.default_startup_program())
 
     for pass_id in range(PASS_NUM):
        accuracy.reset(exe)
+        batch_id = 0
         for data in train_reader():
-            loss, acc = exe.run(fluid.default_main_program(),
+            loss, acc = exe.run(trainer_prog,
                                 feed=feeder.feed(data),
                                 fetch_list=[avg_cost] + accuracy.metrics)
             pass_acc = accuracy.eval(exe)
-            # print loss, acc
-            if loss < 10.0 and pass_acc > 0.9:
-                # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
-                exit(0)
+            if batch_id % 100 == 0:
+                print("batch_id %d, loss: %f, acc: %f" %
+                      (batch_id, loss, pass_acc))
+            batch_id += 1
 
         pass_acc = accuracy.eval(exe)
         print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
 else:
     print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
-
-exit(1)
diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py
index 8dbfbd547a6677517f028997e6269709aac43b67..b77d2b1268f27c5ec3c34839aaad9b75f0132c2e 100644
--- a/python/paddle/v2/fluid/tests/op_test.py
+++ b/python/paddle/v2/fluid/tests/op_test.py
@@ -4,7 +4,7 @@ import random
 import itertools
 import paddle.v2.fluid.core as core
 import collections
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 from paddle.v2.fluid.op import Operator
 from paddle.v2.fluid.executor import Executor
 from paddle.v2.fluid.framework import Program, OpProtoHolder
@@ -491,7 +491,7 @@ class OpTest(unittest.TestCase):
             op_loss.desc.infer_var_type(block.desc)
             op_loss.desc.infer_shape(block.desc)
 
-        param_grad_list = append_backward_ops(
+        param_grad_list = append_backward(
             loss=loss, parameter_list=input_to_check, no_grad_set=no_grad_set)
 
         feed_dict = {
diff --git a/python/paddle/v2/fluid/tests/test_array_read_write_op.py b/python/paddle/v2/fluid/tests/test_array_read_write_op.py
index f6120aedecf1015c279b8f218f5e37f2e598ab91..01321de8eac34d562d99726b1f4125d1932ab40f 100644
--- a/python/paddle/v2/fluid/tests/test_array_read_write_op.py
+++ b/python/paddle/v2/fluid/tests/test_array_read_write_op.py
@@ -2,7 +2,7 @@ import unittest
 import paddle.v2.fluid.core as core
 import paddle.v2.fluid.layers as layers
 from paddle.v2.fluid.executor import Executor
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 from paddle.v2.fluid.framework import default_main_program
 import numpy
 
@@ -64,7 +64,7 @@ class TestArrayReadWrite(unittest.TestCase):
             total_sum = layers.sums(input=[a_sum, x_sum])
             total_sum_scaled = layers.scale(x=total_sum, scale=1 / 6.0)
 
-            append_backward_ops(total_sum_scaled)
+            append_backward(total_sum_scaled)
 
             g_vars = map(default_main_program().global_block().var,
                          [each_x.name + "@GRAD" for each_x in x])
diff --git a/python/paddle/v2/fluid/tests/test_conditional_block.py b/python/paddle/v2/fluid/tests/test_conditional_block.py
index 2b9d8f351a2836cd723d629d4790de1e068d0ea3..7d815123f3454d1457f59202219f9a93bf3d8c31 100644
--- a/python/paddle/v2/fluid/tests/test_conditional_block.py
+++ b/python/paddle/v2/fluid/tests/test_conditional_block.py
@@ -3,7 +3,7 @@ import paddle.v2.fluid.layers as layers
 import paddle.v2.fluid.core as core
 from paddle.v2.fluid.framework import default_startup_program, default_main_program
 from paddle.v2.fluid.executor import Executor
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 import numpy
 
 
@@ -26,7 +26,7 @@ class ConditionalBlock(unittest.TestCase):
         outs = exe.run(feed={'X': x}, fetch_list=[out])[0]
         print outs
         loss = layers.mean(x=out)
-        append_backward_ops(loss=loss)
+        append_backward(loss=loss)
         outs = exe.run(
             feed={'X': x},
             fetch_list=[
diff --git a/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py b/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py
index 5fdabbcf889448114ac4e55e7944cb6c57ba5f3c..c552cb033f1ec8f5843490083edee7b2762b5703 100644
--- a/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py
+++ b/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py
@@ -4,7 +4,7 @@ import numpy
 import paddle.v2.fluid.layers as layers
 from paddle.v2.fluid.framework import Program, program_guard
 from paddle.v2.fluid.executor import Executor
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 
 
 class TestCPULoDTensorArrayOps(unittest.TestCase):
@@ -170,7 +170,7 @@ class TestCPULoDTensorArrayOpGrad(unittest.TestCase):
 
             mean = layers.mean(x=result)
 
-            append_backward_ops(mean)
+            append_backward(mean)
 
         tensor = core.LoDTensor()
         tensor.set(numpy.arange(10).reshape(10, 1).astype('float32'), place)
diff --git a/python/paddle/v2/fluid/tests/test_optimizer.py b/python/paddle/v2/fluid/tests/test_optimizer.py
index 29694be58bce0eb41b05439da35ef07a542ef12a..1eadb7d912629024ee21e30b0a5fa4910bb96e06 100644
--- a/python/paddle/v2/fluid/tests/test_optimizer.py
+++ b/python/paddle/v2/fluid/tests/test_optimizer.py
@@ -2,7 +2,7 @@ import unittest
 
 import paddle.v2.fluid.framework as framework
 import paddle.v2.fluid.optimizer as optimizer
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 
 
 class TestOptimizer(unittest.TestCase):
@@ -102,7 +102,7 @@ class TestMomentumOptimizer(unittest.TestCase):
             dtype="float32", shape=[1], lod_level=0, name="mean.out")
         block.append_op(
             type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
         opts = momentum_optimizer.create_optimization_pass(
@@ -151,7 +151,7 @@ class TestMomentumOptimizer(unittest.TestCase):
         learning_rate = 0.01
         momentum_optimizer = self.MockMomentum(
             learning_rate=learning_rate, momentum=0.2, use_nesterov=True)
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
         opts = momentum_optimizer.create_optimization_pass(
@@ -209,7 +209,7 @@ class TestAdagradOptimizer(unittest.TestCase):
         learning_rate = 0.01
         adagrad_optimizer = self.MockAdagrad(
             learning_rate=learning_rate, epsilon=1.0e-6)
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
         opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
@@ -269,7 +269,7 @@ class TestAdamOptimizer(unittest.TestCase):
         learning_rate = 0.01
         adam_optimizer = self.MockAdam(
             learning_rate=learning_rate, beta1=0.9, beta2=0.999)
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
         opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
@@ -331,7 +331,7 @@ class TestAdamaxOptimizer(unittest.TestCase):
         learning_rate = 0.01
         adamax_optimizer = self.MockAdamax(
             learning_rate=learning_rate, beta1=0.9, beta2=0.999)
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
         opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
@@ -390,7 +390,7 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
         learning_rate = 0.01
         decayed_adagrad_optimizer = self.MockDecayedAdagrad(
             learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6)
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
         opts = decayed_adagrad_optimizer.create_optimization_pass(
diff --git a/python/paddle/v2/fluid/tests/test_recurrent_op.py b/python/paddle/v2/fluid/tests/test_recurrent_op.py
index e38c763ddbcc5c8410f41d062c05499333a3ee55..84f4e36fa7312fbcb96cc66ff26e234c3016df30 100644
--- a/python/paddle/v2/fluid/tests/test_recurrent_op.py
+++ b/python/paddle/v2/fluid/tests/test_recurrent_op.py
@@ -3,7 +3,7 @@ import unittest
 import paddle.v2.fluid.layers as layers
 from paddle.v2.fluid.framework import Program, grad_var_name
 from paddle.v2.fluid.executor import Executor
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 import numpy as np
 import paddle.v2.fluid.core as core
 
@@ -177,7 +177,7 @@ class RecurrentOpTest1(unittest.TestCase):
     def test_backward(self):
         self.check_forward()
 
-        append_backward_ops(self.output)
+        append_backward(self.output)
 
         ana_grad = [np.array(x) for x in self.backward()]
 
diff --git a/python/paddle/v2/fluid/tests/test_regularizer.py b/python/paddle/v2/fluid/tests/test_regularizer.py
index 24baf55e90c98f39bab926e8c85a791eee5ed4a4..890c881a126a32344128652691c6cad45e02e82d 100644
--- a/python/paddle/v2/fluid/tests/test_regularizer.py
+++ b/python/paddle/v2/fluid/tests/test_regularizer.py
@@ -3,7 +3,7 @@ import unittest
 import paddle.v2.fluid.framework as framework
 import paddle.v2.fluid.optimizer as optimizer
 import paddle.v2.fluid.regularizer as regularizer
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 
 
 class TestL2DecayRegularizer(unittest.TestCase):
@@ -33,7 +33,7 @@ class TestL2DecayRegularizer(unittest.TestCase):
             dtype="float32", shape=[1], lod_level=0, name="mean.out")
         block.append_op(
             type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         count_ops = len(block.ops)
         params_grads = optimizer.append_regularization_ops(params_grads)
@@ -70,7 +70,7 @@ class TestL1DecayRegularizer(unittest.TestCase):
             dtype="float32", shape=[1], lod_level=0, name="mean.out")
         block.append_op(
             type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
-        params_grads = append_backward_ops(mean_out)
+        params_grads = append_backward(mean_out)
         self.assertEqual(len(params_grads), 1)
         count_ops = len(block.ops)
         params_grads = optimizer.append_regularization_ops(params_grads)
diff --git a/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py b/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py
index 8f5774835e02191a068e86ea56f3f877c464a391..7c136f6360ce73a7c532b5486e544796e6853bcb 100644
--- a/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py
+++ b/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py
@@ -12,7 +12,7 @@ class TestReorderLoDTensor(unittest.TestCase):
         new_dat = fluid.layers.reorder_lod_tensor_by_rank(
             x=dat, rank_table=table)
         loss = fluid.layers.mean(x=new_dat)
-        fluid.backward.append_backward_ops(loss=loss)
+        fluid.backward.append_backward(loss=loss)
 
         cpu = fluid.CPUPlace()
         exe = fluid.Executor(cpu)
diff --git a/python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py b/python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py
index 9999165ed509aa40f31f26aa676f381561bd0016..d1bb20f37a3785f70bee072b9df282bba4012c16 100644
--- a/python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py
+++ b/python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py
@@ -2,7 +2,7 @@ import unittest
 
 from paddle.v2.fluid.framework import Program
 from paddle.v2.fluid.executor import Executor
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 import numpy as np
 import paddle.v2.fluid.core as core
 
diff --git a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py
index 86db4c64b493d94cc675ed4bcee7e2925fef1977..be1588fc2d09fa58882425eb3d080ef1560ebc79 100644
--- a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py
+++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py
@@ -2,7 +2,7 @@ import unittest
 import paddle.v2.fluid.core as core
 from paddle.v2.fluid.executor import Executor
 import paddle.v2.fluid.layers as layers
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 from paddle.v2.fluid.framework import default_main_program
 import numpy
 
@@ -35,7 +35,7 @@ class TestShrinkRNNMemory(unittest.TestCase):
         self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2]))
 
         mem3_mean = layers.mean(x=mem3)
-        append_backward_ops(loss=mem3_mean)
+        append_backward(loss=mem3_mean)
         x_grad = exe.run(
             feed={'x': tensor},
             fetch_list=[main_program.global_block().var('x@GRAD')])[0]
diff --git a/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py b/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py
index 8cdd59ff3cc7deb57252fc5218d239f86016cb9c..2e4defd55d75c2012f39bea30a6c4de12528e77c 100644
--- a/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py
+++ b/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py
@@ -4,7 +4,7 @@ import numpy as np
 import paddle.v2.fluid.layers as layers
 from paddle.v2.fluid.framework import Program, program_guard
 from paddle.v2.fluid.executor import Executor
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 
 
 class TestCPULoDTensorArrayOps(unittest.TestCase):
@@ -133,7 +133,7 @@ class TestCPUSplitMergeLoDTensorGrad(unittest.TestCase):
             in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
 
         mean = layers.mean(x=out)
-        append_backward_ops(mean)
+        append_backward(mean)
 
         tensor = core.LoDTensor()
         tensor.set(np.arange(10).reshape(10, 1).astype('float32'), place)
diff --git a/python/paddle/v2/fluid/tests/test_while_op.py b/python/paddle/v2/fluid/tests/test_while_op.py
index 033b03a4957131e1155c61e8ed2f10eefb23fda4..7c5593cc5e5a66d4ccb237e3706ff3e544adf033 100644
--- a/python/paddle/v2/fluid/tests/test_while_op.py
+++ b/python/paddle/v2/fluid/tests/test_while_op.py
@@ -2,7 +2,7 @@ import unittest
 import paddle.v2.fluid.layers as layers
 from paddle.v2.fluid.executor import Executor
 import paddle.v2.fluid.core as core
-from paddle.v2.fluid.backward import append_backward_ops
+from paddle.v2.fluid.backward import append_backward
 import numpy
 
 
@@ -46,7 +46,7 @@ class TestWhileOp(unittest.TestCase):
         sum_result = layers.array_read(array=mem_array, i=i)
         loss = layers.mean(x=sum_result)
 
-        append_backward_ops(loss)
+        append_backward(loss)
 
         cpu = core.CPUPlace()
         exe = Executor(cpu)
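
For reviewers who want to try the renamed entry point: below is a minimal, illustrative sketch of calling `append_backward` the same way the updated tests and `Optimizer.minimize` do. The small network (the names `x`, `y`, `y_predict`, `avg_cost`) is made up for this note and is not part of the patch; only the `append_backward(loss, parameter_list, no_grad_set)` call and its list-of-(parameter, gradient)-pairs return value come from the changes above.

    # Hypothetical usage sketch (Python 2, matching the fluid code base of this patch).
    import paddle.v2.fluid as fluid
    from paddle.v2.fluid.backward import append_backward

    # A toy forward network; names are illustrative only.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1)
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    # Same call shape as test_optimizer.py and Optimizer.minimize():
    # returns a list of (parameter, gradient) Variable pairs; no_grad_set
    # may be None or a set of gradient variable names.
    params_and_grads = append_backward(avg_cost)
    for param, grad in params_and_grads:
        print param.name, grad.name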
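The `RowwiseAdd` specializations added in math_function.cc and math_function.cu replace the old Eigen broadcast in math_function_impl.h but keep the same contract: the bias vector is added to every row of the input matrix (the CUDA kernel maps a flat index `i` to column `w = i - (i / width) * width` and computes `c[i] = a[i] + b[w]`). A NumPy sketch of that contract, for reference only and not part of the patch:

    import numpy as np

    def rowwise_add(input, vector):
        # input: [height, width], vector: [width]; mirrors
        # PADDLE_ENFORCE_EQ(vector.numel(), input.numel() / in_dims[0]).
        height, width = input.shape
        assert vector.shape == (width, )
        out = np.empty_like(input)
        for i in range(height):
            # CPU version: out.chip(i, 0) = in.chip(i, 0) + vec
            out[i, :] = input[i, :] + vector
        return out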