diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 26ee5131c1de6c001574fd3448b1225b0c0a923f..7e061c460682824625b8c6202fdce0f833a5cc11 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -120,13 +120,12 @@ void Executor::Close() {
 
 void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                                int block_id) {
+  VLOG(3) << "Creating Variables for block " << block_id;
   auto& global_block = pdesc.Block(block_id);
-
   const Scope* ancestor_scope = scope;
   while (ancestor_scope->parent()) {
     ancestor_scope = ancestor_scope->parent();
   }
-
   if (ancestor_scope != scope) {
     for (auto& var : global_block.AllVars()) {
       if (var->Name() == framework::kEmptyVarName) {
@@ -196,11 +195,12 @@ void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope, bool create_vars,
                    const std::vector<std::string>& skip_ref_cnt_vars,
-                   bool force_disable_gc) {
+                   bool force_disable_gc, bool keep_kid_scopes) {
   platform::RecordBlock b(block_id);
   if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
   auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
-  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
+  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
+                     keep_kid_scopes);
 }
 
 // Check whether the block already has feed operators and feed_holder.
@@ -465,6 +465,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
       // the sub scopes it created should not be dropped immediately, because
      // while_grad_op will use some variables created during while_op run, so
      // we need to keep the kids and wait for the outer executor to drop them.
+      scope->DropKids();
     }
   }
 
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 33e0587ebf5089653461f7e41612409dbb71f5ab..cc663f220955540411eacff2c2b0704784cf0427 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -71,12 +71,18 @@ class Executor {
    * @param
    *  ProgramDesc
    *  Scope
+   *  block_id
+   *  create_local_scope
+   *  create_vars
+   *  skip_ref_cnt_vars
+   *  force_disable_gc
+   *  keep_kid_scopes
    */
   void Run(const ProgramDesc& prog, Scope* scope, int block_id,
            bool create_local_scope = true, bool create_vars = true,
            const std::vector<std::string>& skip_ref_cnt_vars =
                std::vector<std::string>(),
-           bool force_disable_gc = false);
+           bool force_disable_gc = false, bool keep_kid_scopes = false);
 
   // This API is very slow.
   void Run(const ProgramDesc& program, Scope* scope,
diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc
index afafff5218ccf95fdc4baf7282d4f2757a74ac9c..6b83c04732357752f648a0273d6375233f629c2e 100644
--- a/paddle/fluid/framework/scope.cc
+++ b/paddle/fluid/framework/scope.cc
@@ -93,6 +93,11 @@ const Scope* Scope::FindScope(const Variable* var) const {
   return FindScopeInternal(var);
 }
 
+const Scope* Scope::FindScope(const std::string& name) const {
+  SCOPE_VARS_READER_LOCK
+  return FindScopeInternal(name);
+}
+
 void Scope::DropKids() {
   SCOPE_KIDS_WRITER_LOCK
   for (Scope* s : kids_) delete s;
@@ -174,6 +179,13 @@ const Scope* Scope::FindScopeInternal(const Variable* var) const {
   return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
 }
 
+const Scope* Scope::FindScopeInternal(const std::string& name) const {
+  if (vars_.find(name) != vars_.end()) {
+    return this;
+  }
+  return (parent_ == nullptr) ? nullptr : parent_->FindScope(name);
+}
+
 void Scope::RenameInternal(const std::string& origin_name,
                            const std::string& new_name) const {
   auto origin_it = vars_.find(origin_name);
@@ -196,7 +208,9 @@ Variable* Scope::FindVarInternal(const std::string& name) const {
 
 Variable* Scope::FindVarLocally(const std::string& name) const {
   auto it = vars_.find(name);
-  if (it != vars_.end()) return it->second.get();
+  if (it != vars_.end()) {
+    return it->second.get();
+  }
   return nullptr;
 }
 
diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h
index d3e2f33d2e3788c7ed1ff9a77d2936ca0d32c767..db7010ecceb3e7c39cdfd78c5e82074dba199fc7 100644
--- a/paddle/fluid/framework/scope.h
+++ b/paddle/fluid/framework/scope.h
@@ -85,6 +85,9 @@ class Scope {
   /// Find the scope or an ancestor scope that contains the given variable.
   const Scope* FindScope(const Variable* var) const;
 
+  /// Find the scope or an ancestor scope that contains the given variable name.
+  const Scope* FindScope(const std::string& name) const;
+
   void DeleteScope(Scope* scope) const;
 
   /// Drop all kids scopes belonged to this scope.
@@ -125,6 +128,9 @@ class Scope {
   // Called by FindScope.
   const Scope* FindScopeInternal(const Variable* var) const;
 
+  // Called by FindScope.
+  const Scope* FindScopeInternal(const std::string& name) const;
+
   // Called by Rename.
   void RenameInternal(const std::string& origin_name,
                       const std::string& new_name) const;
diff --git a/paddle/fluid/framework/scope_test.cc b/paddle/fluid/framework/scope_test.cc
index ebf8178a8319cd33f2cc5eacb95b163043c986b5..26817fc558dfab6926b67ee744b0a2ef548b4ffb 100644
--- a/paddle/fluid/framework/scope_test.cc
+++ b/paddle/fluid/framework/scope_test.cc
@@ -53,7 +53,9 @@ TEST(Scope, FindScope) {
   Variable* v = s.Var("a");
 
   EXPECT_EQ(&s, s.FindScope(v));
+  EXPECT_EQ(&s, s.FindScope("a"));
   EXPECT_EQ(&s, ss.FindScope(v));
+  EXPECT_EQ(&s, ss.FindScope("a"));
 }
 
 TEST(Scope, GetAllNames) {
diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index d1751ba6fa2d70fac555eeaed9adc1aad8caccf6..c74c4ebbd886400ff0c7aeb38de6be176ea67fe4 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/controlflow/conditional_block_op.h"
+#include "paddle/fluid/operators/assign_op.h"
 
 namespace paddle {
 namespace operators {
@@ -58,13 +59,15 @@ class ConditionalBlockOp : public ConditionalOp {
       scopes->resize(1);
       scopes->front() = &scope.NewScope();
       auto &cur_scope = *scopes->front();
-
       framework::Executor exec(dev_place);
       auto *block = Attr<framework::BlockDesc *>("sub_block");
+      VLOG(3) << "Conditional block.idx = " << block->ID()
+              << ", scope = " << &cur_scope;
       auto &skip_vars =
          Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
       exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
-               skip_vars);
+               skip_vars, /* force_disable_gc */ false,
+               /* keep_kid_scopes */ true);
     }
   }
 };
@@ -92,60 +95,65 @@ class ConditionalBlockGradOp : public ConditionalOp {
     }
 
     if (need_run) {
+      const auto &inputs = Inputs(ConditionalOp::kInputs);
+      const auto &outside_grads =
+          Outputs(framework::GradVarName(ConditionalOp::kInputs));
+
+      std::vector<std::string> inside_grads;
+      inside_grads.reserve(inputs.size());
+      for (auto &in : inputs) {
+        inside_grads.emplace_back(framework::GradVarName(in));
+      }
+
       auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
-      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
+      PADDLE_ENFORCE_NE(scope_var, nullptr,
+                        platform::errors::InvalidArgument(
+                            "Scope must be set in conditional block op"));
       auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
+      PADDLE_ENFORCE_GT(scopes.size(), 0,
+                        platform::errors::InvalidArgument(
+                            "Scope must be set in conditional block op"));
       framework::Scope &cur_scope = *scopes[0];
 
       framework::Executor exec(dev_place);
       auto *block = Attr<framework::BlockDesc *>("sub_block");
 
-      const auto &ins = Inputs(ConditionalOp::kInputs);
-      const auto &d_ins =
-          Outputs(framework::GradVarName(ConditionalOp::kInputs));
-      const auto &conds = Inputs(ConditionalOp::kCondition);
-      const auto &d_conds =
-          Outputs(framework::GradVarName(ConditionalOp::kCondition));
-
-      std::vector<std::string> ins_conds_grads;
-      ins_conds_grads.reserve(ins.size() + conds.size());
-      for (auto &in : ins) {
-        ins_conds_grads.emplace_back(framework::GradVarName(in));
-      }
-      for (auto &cond : conds) {
-        ins_conds_grads.emplace_back(framework::GradVarName(cond));
-      }
-
+      VLOG(3) << "Conditional Grad block.idx = " << block->ID()
+              << ", scope = " << &cur_scope;
       exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
-               ins_conds_grads);
-
-      AssignLocalGradientToGlobal(dev_place, cur_scope, ins_conds_grads.data(),
-                                  ins.size(), d_ins);
+               inside_grads, /* force_disable_gc */ false,
+               /* keep_kid_scopes */ false);
 
-      AssignLocalGradientToGlobal(dev_place, cur_scope,
-                                  ins_conds_grads.data() + ins.size(),
-                                  conds.size(), d_conds);
+      AssignLocalGradientToParentScope(dev_place, cur_scope, scope,
+                                       inside_grads, outside_grads);
     }
   }
 
 private:
-  void AssignLocalGradientToGlobal(
+  void AssignLocalGradientToParentScope(
       const platform::Place &place, const framework::Scope &cur_scope,
-      const std::string *p_grad_names, size_t p_grad_names_num,
-      const std::vector<std::string> &pg_names) const {
-    for (size_t i = 0; i < p_grad_names_num; ++i) {
-      auto out_grad_name = pg_names[i];
-      const auto &in_grad_name = p_grad_names[i];
-      auto *in_var = cur_scope.FindVar(in_grad_name);
-      if (in_var == nullptr) {
+      const framework::Scope &parent_scope,
+      const std::vector<std::string> &inside_grads,
+      const std::vector<std::string> &outside_grads) const {
+    for (size_t i = 0; i < outside_grads.size(); ++i) {
+      const std::string &outside_grad_name = outside_grads[i];
+      const std::string &inside_grad_name = inside_grads[i];
+      VLOG(4) << "inside_grad_name = " << inside_grad_name
+              << ", outside_grad_name = " << outside_grad_name;
+      framework::Variable *inside_var =
+          cur_scope.FindLocalVar(inside_grad_name);
+      if (inside_var == nullptr) {
         continue;
       }
-      auto new_in_grad_name = cur_scope.Rename(in_grad_name);
-      auto assign = framework::OpRegistry::CreateOp(
-          "assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
-          framework::AttributeMap{});
-      assign->Run(cur_scope, place);
-      cur_scope.Rename(new_in_grad_name, in_grad_name);
+      framework::Variable *outside_var =
+          parent_scope.FindVar(outside_grad_name);
+      if (outside_var == nullptr) {
+        continue;
+      }
+      platform::DeviceContext *dev_ctx =
+          platform::DeviceContextPool::Instance().Get(place);
+      framework::VisitVarType(*inside_var,
+                              AssignFunctor(outside_var, *dev_ctx));
     }
   }
 };
@@ -154,17 +162,11 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInputs(ConditionalOp::kCondition));
-    if (context->HasInputs(ConditionalOp::kInputs)) {
-      PADDLE_ENFORCE(
-          context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs)));
+    if (context->HasInputs(ConditionalOp::kInputs) &&
+        context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs))) {
       context->SetOutputsDim(framework::GradVarName(ConditionalOp::kInputs),
                              context->GetInputsDim(ConditionalOp::kInputs));
     }
-    if (context->HasOutputs(
-            framework::GradVarName(ConditionalOp::kCondition))) {
-      context->SetOutputsDim(framework::GradVarName(ConditionalOp::kCondition),
-                             context->GetInputsDim(ConditionalOp::kCondition));
-    }
   }
 };
 
@@ -187,8 +189,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker {
                       this->OutputGrad(ConditionalOp::kOutputs));
     grad_op->SetInput(ConditionalOp::kScope,
                       this->Output(ConditionalOp::kScope));
-    grad_op->SetOutput(framework::GradVarName(ConditionalOp::kCondition),
-                       this->InputGrad(ConditionalOp::kCondition, false));
     grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs),
                        this->InputGrad(ConditionalOp::kInputs, false));
     grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc
index 37e062982e1f6e919eb78c79fd21c64c1184b20d..3e2bd21d7029c7e24f8033fcfa93350e115bf9fd 100644
--- a/paddle/fluid/operators/print_op.cc
+++ b/paddle/fluid/operators/print_op.cc
@@ -15,6 +15,7 @@
 #include
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/operators/assign_op.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu
index fa6ae65a0e75d365cccd6d8b5efd1eac3132e978..272cf3573fb2f27c62bce86f0e97b8e567b245ae 100644
--- a/paddle/fluid/operators/sum_op.cu
+++ b/paddle/fluid/operators/sum_op.cu
@@ -128,19 +128,19 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
       in_vars[1]->IsType<framework::LoDTensor>()) {
     auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
     auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
-
-    auto length = in_0.numel();
-    if (length && in_0.IsInitialized() && in_1.IsInitialized()) {
+    int64_t length_0 = in_0.numel();
+    int64_t length_1 = in_1.numel();
+    if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
       auto result = EigenVector<T>::Flatten(*out);
       auto &place = *dev_ctx.eigen_device();
       auto in_0_e = EigenVector<T>::Flatten(in_0);
       auto in_1_e = EigenVector<T>::Flatten(in_1);
       result.device(place) = in_0_e + in_1_e;
-    } else if (length && in_0.IsInitialized()) {
+    } else if (length_0 && in_0.IsInitialized()) {
       auto result = EigenVector<T>::Flatten(*out);
       auto &place = *dev_ctx.eigen_device();
       result.device(place) = EigenVector<T>::Flatten(in_0);
-    } else if (length && in_1.IsInitialized()) {
+    } else if (length_1 && in_1.IsInitialized()) {
       auto result = EigenVector<T>::Flatten(*out);
       auto &place = *dev_ctx.eigen_device();
       result.device(place) = EigenVector<T>::Flatten(in_1);
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 3eca39c2b180439a3ea1647b616482198352f9d6..54e7066ba5f97f3ae4720149dc98a52165283319 100644
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -321,7 +321,7 @@ def _append_grad_suffix_(name):
     return cpt.to_text(name) + core.grad_var_suffix()
 
 
-def _addup_repetitive_outputs_(op_descs):
+def _addup_repetitive_outputs_(op_descs, block_idx):
     """
     In backward part, an variable may be the output of more than one ops.
     And one op may yield its multiple outputs to the same variable.
@@ -358,7 +358,7 @@ def _addup_repetitive_outputs_(op_descs):
                     renamed_var_start_idx[var_name] = idx
                 else:
                     if len(renamed_vars[var_name]) == 1:
-                        new_name = var_name + "@RENAME@" + \
+                        new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
                             str(var_rename_count[var_name])
                         var_rename_count[var_name] += 1
                         # rename original var_name
@@ -384,7 +384,7 @@ def _addup_repetitive_outputs_(op_descs):
                                 for x in arg_names[:arg_idx]
                             ] + arg_names[arg_idx:]
 
-                    new_name = var_name + "@RENAME@" + \
+                    new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
                         str(var_rename_count[var_name])
                     var_rename_count[var_name] += 1
                     arg_names[arg_idx] = new_name
@@ -733,7 +733,7 @@ def _append_backward_ops_with_checkpoints_(
         grad_to_var.update(op_grad_to_var)
 
     # 3.d. add sum op for repetitive_outputs
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
     # 4) remove no grad branch as it is in _remove_no_grad_branch_
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx])
@@ -741,6 +741,35 @@ def _append_backward_ops_with_checkpoints_(
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
 
 
+def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
+    """
+    Get output vars in subblock which will be assigned to parent block.
+    It is used to find the grad path in subblock
+    """
+    assert sub_block_op_desc.has_attr(
+        "sub_block") and sub_block.idx == sub_block_op_desc._block_attr_id(
+            "sub_block")
+    # TODO(huihuangzheng): add support for recurrent op and while op
+    if sub_block_op_desc.type == "conditional_block":
+        sub_outputs = []
+        sub_assign_to_out_ops = []
+        for var in sub_block_op_desc.output_arg_names:
+            for op_desc in sub_block.ops:
+                if op_desc.type == "assign" and var in op_desc.output_arg_names:
+                    sub_assign_to_out_ops.append(op_desc)
+                    sub_outputs.extend([
+                        sub_block.var(name) for name in op_desc.input_arg_names
+                    ])
+        sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
+                                           no_grad_set)
+        # TODO better way than finding in list
+        for op_desc in sub_assign_to_out_ops:
+            if op_desc not in sub_block_op_path:
+                sub_block_op_path.append(op_desc)
+        return sub_block_op_path
+    return sub_block.ops
+
+
 def _append_backward_ops_(block,
                           ops,
                           target_block,
@@ -775,6 +804,8 @@ def _append_backward_ops_(block,
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
+
+    # add grad_op_desc by reversed ops
     for op in reversed(ops):
         grad_sub_block_list = []
         # If the op has its own sub-block, deal with the sub-block first
@@ -785,7 +816,9 @@ def _append_backward_ops_(block,
             # see follwing comments for why set None here.
             pre_input_grad_names_set = copy.copy(input_grad_names_set)
             input_grad_names_set = None
-            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+            sub_block_path = _get_sub_block_path(sub_block, op,
+                                                 no_grad_dict[sub_block.idx])
+            _append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
                                   no_grad_dict, grad_to_var, callbacks,
                                   input_grad_names_set)
             input_grad_names_set = pre_input_grad_names_set
@@ -825,10 +858,8 @@ def _append_backward_ops_(block,
             grad_op_descs.extend(grad_op_desc)
             grad_to_var.update(op_grad_to_var)
 
-    # add grad_op_desc by reversed ops
-
     # sum parameter's gradients' var given multiple var gradient
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
 
     # if all outputs of the grad op are in no_grad_set, then just remove and fill zero
     # if all inputs of the grad op are in no_grad_set, just remove this op
@@ -841,6 +872,7 @@ def _append_backward_ops_(block,
     grad_op_descs = [
         op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
     ]
 
+    # append op_desc in grad_op_descs to target_block
     op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
     backward = core.op_proto_and_checker_maker.OpRole.Backward
diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py
index 87dd95bf1442a520994abad92ad33ae35a46c8ca..45933463f6ef96d24906178bc5a4725d76097e49 100755
--- a/python/paddle/fluid/layers/control_flow.py
+++ b/python/paddle/fluid/layers/control_flow.py
@@ -22,7 +22,7 @@ from ..framework import Program, Variable, Operator
 from ..layer_helper import LayerHelper, unique_name
 from ..initializer import force_init_on_cpu
 from .nn import logical_and, logical_not, logical_or
-from .utils import assert_same_structure, flatten, map_structure
+from .utils import assert_same_structure, map_structure
 import numpy
 import warnings
 import six
@@ -1710,7 +1710,6 @@ class ConditionalBlock(object):
 
         param_list = [
             parent_block._var_recursive(each_name) for each_name in params
-            if each_name not in input_set
         ]
 
         out_list = []
@@ -1755,7 +1754,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
     helper = LayerHelper('cond', **locals())
     true_output = None
     false_output = None
-    copy_to_global_func = lambda var: copy_var_to_parent_block(var, helper)
+    copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
     if true_fn is not None:
         if not callable(true_fn):
             raise TypeError("The true_fn in cond must be callable")
@@ -1763,7 +1762,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
         with true_cond_block.block():
             origin_true_output = true_fn()
             if origin_true_output is not None:
-                true_output = map_structure(copy_to_global_func,
+                true_output = map_structure(copy_to_parent_func,
                                             origin_true_output)
     if false_fn is not None:
         if not callable(false_fn):
@@ -1773,7 +1772,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
         with false_cond_block.block():
             origin_false_output = false_fn()
             if origin_false_output is not None:
-                false_output = map_structure(copy_to_global_func,
+                false_output = map_structure(copy_to_parent_func,
                                              origin_false_output)
 
     if true_output is None and false_output is None:
diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py
index 370f95b1ac529f7e0813a857eb3a99b1db271a30..acbe97dd1214cf94aeb88ff90273d54c67112c8d 100644
--- a/python/paddle/fluid/tests/unittests/test_cond.py
+++ b/python/paddle/fluid/tests/unittests/test_cond.py
@@ -21,11 +21,12 @@ import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.fluid.layers as layers
 import paddle.fluid.framework as framework
+from paddle.fluid.backward import append_backward
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program, program_guard
 
 
-class TestCond(unittest.TestCase):
+class TestCondInputOutput(unittest.TestCase):
     def test_return_single_var(self):
         """
         pseudocode:
@@ -220,5 +221,59 @@ class TestCond(unittest.TestCase):
             str(e.exception))
 
 
+class TestCondNestedControlFlow(unittest.TestCase):
+    def test_cond_inside_cond(self):
+        """
+        pseudocode:
+        for i in range(1, 10):
+            a = 2 * i
+            if i < 5:
+                if i >= 3:
+                    return a + a
+                else:
+                    return a - a
+            else:
+                if i < 8:
+                    return a * a
+                else:
+                    return a / a
+        """
+
+        def less_than_branch(i, a):
+            return layers.cond(i >= 3.0, lambda: layers.elementwise_add(a, a),
+                               lambda: layers.elementwise_sub(a, a))
+
+        def greater_equal_branch(i, a):
+            return layers.cond(i < 8.0, lambda: layers.elementwise_mul(a, a),
+                               lambda: layers.elementwise_div(a, a))
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            i = fluid.data(name="i", shape=[1], dtype='float32')
+            a = 2.0 * i
+            out = layers.cond(i < 5.0, lambda: less_than_branch(i, a),
+                              lambda: greater_equal_branch(i, a))
+            mean = layers.mean(out)
+            append_backward(mean)
+
+        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        for feed_i in range(0, 10):
+            expected_a = 2.0 * feed_i
+            if feed_i < 5:
+                expected_ret = expected_a + expected_a if feed_i >= 3 else 0.0
+                expected_a_grad = 2.0 if feed_i >= 3 else 0.0
+            else:
+                expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
+                expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
+            ret = exe.run(main_program,
+                          feed={'i': np.full((1), feed_i, np.float32)},
+                          fetch_list=[out.name, a.grad_name])
+            self.assertEqual(ret[0][0], expected_ret)
+            self.assertEqual(ret[1][0], expected_a_grad)
+
+
 if __name__ == '__main__':
     unittest.main()