From 4624d7c642ef128cd96f64bd344a367b2683a5ca Mon Sep 17 00:00:00 2001 From: Jiabin Yang Date: Mon, 13 May 2019 16:25:15 +0800 Subject: [PATCH] test=develop, add gradient sort backward strategy (#17125) * test=develop, add gradient sort backward strategy * test=develop, fix test by add FLAGS_cudnn_deterministic on new tests --- paddle/fluid/imperative/backward_strategy.h | 38 ++ paddle/fluid/imperative/layer.cc | 116 +++++- paddle/fluid/imperative/layer.h | 8 +- paddle/fluid/imperative/type_defs.h | 7 + paddle/fluid/pybind/pybind.cc | 17 +- python/paddle/fluid/dygraph/__init__.py | 4 + .../paddle/fluid/dygraph/backward_strategy.py | 19 + python/paddle/fluid/framework.py | 13 +- .../fluid/tests/unittests/CMakeLists.txt | 6 + .../tests/unittests/test_imperative_basic.py | 56 +++ .../tests/unittests/test_imperative_deepcf.py | 28 ++ .../tests/unittests/test_imperative_gan.py | 48 +++ .../tests/unittests/test_imperative_gnn.py | 29 +- .../test_imperative_mnist_sorted_gradient.py | 149 +++++++ ...test_imperative_ptb_rnn_sorted_gradient.py | 165 ++++++++ .../test_imperative_resnet_sorted_gradient.py | 230 +++++++++++ ..._imperative_transformer_sorted_gradient.py | 367 ++++++++++++++++++ 17 files changed, 1275 insertions(+), 25 deletions(-) create mode 100644 paddle/fluid/imperative/backward_strategy.h create mode 100644 python/paddle/fluid/dygraph/backward_strategy.py create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py diff --git a/paddle/fluid/imperative/backward_strategy.h b/paddle/fluid/imperative/backward_strategy.h new file mode 100644 index 0000000000..9ff07d6d79 --- /dev/null +++ b/paddle/fluid/imperative/backward_strategy.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Created by Jiabin on 2019-04-25. 
+// +#pragma once +#ifndef PADDLE_BACKWARDSTRATEGY_H +#define PADDLE_BACKWARDSTRATEGY_H + +#endif // PADDLE_BACKWARDSTRATEGY_H + +namespace paddle { +namespace imperative { +namespace detail { + +class BackwardStrategy { + public: + /* DyGraph now support two kinds of backward strategy, one is sorted sum + * gradient, another is sum gradient once they are created */ + // TODO(jiabin): add more Strategy when we support + bool sorted_sum_gradient_{false}; +}; + +} // namespace detail +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index aa739a8972..2458b95448 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -101,26 +101,50 @@ void AddTo(Variable* src, Variable* dst, platform::Place place) { boost::apply_visitor(func, place); } +void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) { + PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(), + "Can't find %s in backward grad map", target->Name()); + std::pair>>& current = + bck_map->at(target); + std::sort( + current.second.begin(), current.second.end(), + [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + for (auto& var_pair : current.second) { + Variable* origin_grad = target->var_; + Variable* grad_to_add = var_pair.second->var_; + VLOG(2) << "add origin_grad: " << target->Name(); + VLOG(2) << "added grad: " << var_pair.second->Name() + << " trace id is: " << var_pair.first; + AddTo(grad_to_add, origin_grad, current.first); + delete grad_to_add; + var_pair.second = nullptr; + } +} + class Autograd { public: Autograd() {} - void RunBackward(VarBase* var) { + void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) { if (var->IsStopGradient()) { return; } VLOG(3) << "start autograd"; - + bck_map = new BackwardSumMap(); + grad_ref = new GradientRef(); std::deque ready; ready.push_back(var->PreOp()); - std::map dep_counts = ComputeDepCounts(var->PreOp()); + std::map dep_counts = + ComputeDepCounts(var->PreOp(), bck_stratedy); while (!ready.empty()) { OpBase* ready_op = ready.front(); ready.pop_front(); std::map> input_grads = - ready_op->ApplyGrad(); + ready_op->ApplyGrad(bck_map, grad_ref, bck_stratedy); for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) { const std::vector& ingrads = it->second; @@ -146,7 +170,8 @@ class Autograd { } private: - std::map ComputeDepCounts(OpBase* op) { + std::map ComputeDepCounts( + OpBase* op, const detail::BackwardStrategy& bck_stratedy) { std::map ret; std::deque queue; @@ -156,10 +181,25 @@ class Autograd { while (!queue.empty()) { OpBase* candidate = queue.front(); queue.pop_front(); + if (bck_stratedy.sorted_sum_gradient_) { + for (const auto& map : candidate->grad_output_vars_) { + for (const auto& it : map) { + for (const auto& vb : it.second) { + if (grad_ref->find(vb) == grad_ref->end()) { + grad_ref->insert(std::make_pair(vb, 1)); + } else { + // add ref count by 1 when we find grad_var can be generated by + // one grad_op + grad_ref->at(vb) += 1; + } + } + } + } + } for (auto it : candidate->pre_ops_) { for (OpBase* pre_op : it.second) { if (!pre_op) continue; - VLOG(5) << "op dep " << candidate->Type() << " trace id " + VLOG(2) << "op dep " << candidate->Type() << " trace id " << candidate->trace_id_ << " <---- " << it.first << " <---- " << pre_op->Type() << " trace id " << pre_op->trace_id_; if (visited.find(pre_op) == visited.end()) { @@ -172,6 +212,9 @@ class Autograd { } return ret; } + + 
BackwardSumMap* bck_map; + GradientRef* grad_ref; }; std::unique_ptr VarBase::NewVarBase(const platform::Place& dst_place, @@ -213,7 +256,9 @@ framework::LoDTensor& VarBase::GradValue() { return *(grads_->var_->GetMutable()); } -std::map> OpBase::ApplyGrad() { +std::map> OpBase::ApplyGrad( + BackwardSumMap* bck_map, GradientRef* grad_ref, + const detail::BackwardStrategy& bck_stratedy) { PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0, "%s has no backward implementation", Type()); @@ -313,13 +358,52 @@ std::map> OpBase::ApplyGrad() { PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size()); for (size_t i = 0; i < outputs.size(); ++i) { - framework::Variable* grad = outputs[i]->var_; - framework::Variable* orig_grad = origin_outputs[i]->var_; - VLOG(3) << "AddTo Called with orig_grad is: " - << origin_outputs[i]->name_ << " Grad to be added is " - << outputs[i]->name_; - AddTo(grad, orig_grad, place_); - delete grad; + // track outputs used by sum + if (bck_stratedy.sorted_sum_gradient_) { +#ifndef PADDLE_WITH_CUDA + VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " "; + VLOG(2) << origin_outputs[i] + ->var_->GetMutable() + ->data()[0]; + VLOG(2) << "outputs is : " << outputs[i]->Name() << " "; + VLOG(2) << outputs[i] + ->var_->GetMutable() + ->data()[0]; +#endif + if (bck_map->find(origin_outputs[i]) != bck_map->end()) { + VLOG(2) << "add sub grad to " << origin_outputs[i]->Name(); + bck_map->at(origin_outputs[i]) + .second.emplace_back( + std::pair(this->trace_id_, outputs[i])); + } else { + VLOG(2) << "insert new map for " << origin_outputs[i]->Name(); + std::pair>> + tmp(place_, {std::make_pair(this->trace_id_, outputs[i])}); + bck_map->insert(std::make_pair(origin_outputs[i], tmp)); + } + + PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(), + "Can't find %s in grad_reference count map", + origin_outputs[i]->Name()); + PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1, + "Backward error when calculate grad reference"); + if (grad_ref->at(origin_outputs[i]) > 1) { + VLOG(2) << "remove ref for " << origin_outputs[i]->Name(); + grad_ref->at(origin_outputs[i])--; + } else { + VLOG(2) << "Add grad for: " << origin_outputs[i]->Name(); + AddGradBySort(bck_map, origin_outputs[i]); + grad_ref->at(origin_outputs[i])--; + } + } else { + framework::Variable* grad = outputs[i]->var_; + framework::Variable* orig_grad = origin_outputs[i]->var_; + VLOG(2) << "AddTo Called with orig_grad is: " + << origin_outputs[i]->name_ << " Grad to be added is " + << outputs[i]->name_; + AddTo(grad, orig_grad, place_); + delete grad; + } } } } @@ -347,7 +431,7 @@ void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) { } } -void VarBase::RunBackward() { +void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) { if (!pre_op_) return; VLOG(3) << "start backward"; @@ -360,7 +444,7 @@ void VarBase::RunBackward() { PADDLE_ENFORCE( grads_ == pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_); - Autograd().RunBackward(this); + Autograd().RunBackward(this, bck_stratedy); } void PyLayer::RegisterFunc(int func_id, const py::object& py_func) { diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 37488d381e..76d98640af 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -31,7 +31,7 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/operators/math/math_function.h" - +#include 
"paddle/fluid/imperative/backward_strategy.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { @@ -217,7 +217,7 @@ class VarBase { inline OpBase* PreOp() const { return pre_op_; } inline int PreOpOutIdx() const { return pre_op_out_idx_; } - void RunBackward(); + void RunBackward(const detail::BackwardStrategy& bck_stratedy); inline void ResetPreOp(OpBase* op) { if (op == pre_op_) { @@ -302,7 +302,9 @@ class PYBIND11_HIDDEN OpBase { } } - std::map> ApplyGrad(); + std::map> ApplyGrad( + BackwardSumMap* bck_map, GradientRef* grad_ref, + const detail::BackwardStrategy& bck_stratedy); inline std::string Type() const { return type_; } inline std::string GradOpType(size_t index) const { diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h index c51ce931de..13d08cbb71 100644 --- a/paddle/fluid/imperative/type_defs.h +++ b/paddle/fluid/imperative/type_defs.h @@ -16,6 +16,8 @@ limitations under the License. */ #include #include +#include +#include #include namespace paddle { @@ -27,6 +29,11 @@ class OpBase; typedef std::map> VarBasePtrMap; typedef std::map> ConstVarBasePtrMap; typedef std::map> OpBasePtrMap; +typedef std::unordered_map< + const VarBase*, + std::pair>>> + BackwardSumMap; // var_grad -> {place, {id -> var_grad@rename}} +typedef std::unordered_map GradientRef; } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 63d37223ca..43322b796b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -184,6 +184,18 @@ PYBIND11_MODULE(core, m) { m.def("print_mem_usage", []() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); }); + py::class_ backward_strategy( + m, "BackwardStrategy", R"DOC()DOC"); + backward_strategy.def(py::init()) + .def_property("sort_sum_gradient", + [](const imperative::detail::BackwardStrategy &self) { + return self.sorted_sum_gradient_; + }, + [](imperative::detail::BackwardStrategy &self, + bool sorted_sum_gradient) { + self.sorted_sum_gradient_ = sorted_sum_gradient; + }); + m.def("start_imperative_gperf_profiler", []() { imperative::StartProfile(); }); @@ -199,7 +211,10 @@ PYBIND11_MODULE(core, m) { const std::vector, const paddle::platform::CUDAPlace, bool, bool>()) .def("_run_backward", - [](imperative::VarBase &self) { self.RunBackward(); }) + [](imperative::VarBase &self, + const imperative::detail::BackwardStrategy &bckst) { + self.RunBackward(bckst); + }) .def("_grad_name", &imperative::VarBase::GradName) .def("_grad_value", &imperative::VarBase::GradValue) .def("_clear_gradient", &imperative::VarBase::ClearGradient) diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 9bb72ede30..7ab1dfdf76 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -38,6 +38,9 @@ from .checkpoint import * from . import learning_rate_scheduler from .learning_rate_scheduler import * +from . 
import backward_strategy +from .backward_strategy import * + __all__ = [] __all__ += layers.__all__ __all__ += base.__all__ @@ -47,3 +50,4 @@ __all__ += profiler.__all__ __all__ += parallel.__all__ __all__ += checkpoint.__all__ __all__ += learning_rate_scheduler.__all__ +__all__ += backward_strategy.__all__ diff --git a/python/paddle/fluid/dygraph/backward_strategy.py b/python/paddle/fluid/dygraph/backward_strategy.py new file mode 100644 index 0000000000..bfcf66af31 --- /dev/null +++ b/python/paddle/fluid/dygraph/backward_strategy.py @@ -0,0 +1,19 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid import core + +__all__ = ["BackwardStrategy"] + +BackwardStrategy = core.BackwardStrategy diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 892d026550..08b5789f0b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -529,8 +529,17 @@ class Variable(object): new_ivar = self._ivar._copy_to(core.CPUPlace(), True) return np.array(new_ivar.value().get_tensor()) - def backward(self): - self._ivar._run_backward() + def backward(self, backward_strategy=None): + from .dygraph import BackwardStrategy + if isinstance(backward_strategy, BackwardStrategy): + self._ivar._run_backward(backward_strategy) + elif backward_strategy is not None: + raise TypeError( + "only BackwardStrategy type should be passed in backward") + else: + backward_strategy = BackwardStrategy() + backward_strategy.sort_sum_gradient = False + self._ivar._run_backward(backward_strategy) def gradient(self): new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index aa4fc5ceb9..a8e3459d1d 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -80,6 +80,8 @@ list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer) list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op) list(REMOVE_ITEM TEST_OPS test_nearest_interp_op) list(REMOVE_ITEM TEST_OPS test_imperative_resnet) +list(REMOVE_ITEM TEST_OPS test_imperative_resnet_sorted_gradient) +list(REMOVE_ITEM TEST_OPS test_imperative_mnist_sorted_gradient) list(REMOVE_ITEM TEST_OPS test_imperative_se_resnext) list(REMOVE_ITEM TEST_OPS test_imperative_mnist) list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer) @@ -129,8 +131,12 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${G py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS} SERIAL) py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS FLAGS_cudnn_deterministic=1 SERIAL) +py_test_modules(test_imperative_resnet_sorted_gradient MODULES test_imperative_resnet_sorted_gradient ENVS + FLAGS_cudnn_deterministic=1 SERIAL) py_test_modules(test_imperative_mnist MODULES test_imperative_mnist ENVS 
FLAGS_cudnn_deterministic=1 SERIAL) +py_test_modules(test_imperative_mnist_sorted_gradient MODULES test_imperative_mnist_sorted_gradient ENVS + FLAGS_cudnn_deterministic=1 SERIAL) py_test_modules(test_imperative_se_resnext MODULES test_imperative_se_resnext ENVS FLAGS_cudnn_deterministic=1 SERIAL) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 8404a57eb8..e7c1f9fda2 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -201,8 +201,21 @@ class TestImperative(unittest.TestCase): ret = fluid.layers.sums(inputs) loss = fluid.layers.reduce_sum(ret) loss.backward() + with fluid.dygraph.guard(): + inputs2 = [] + for _ in range(10): + inputs2.append(fluid.dygraph.base.to_variable(x)) + ret2 = fluid.layers.sums(inputs2) + loss2 = fluid.layers.reduce_sum(ret2) + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + loss2.backward(backward_strategy) + self.assertTrue(np.allclose(ret.numpy(), x * 10)) self.assertTrue(np.allclose(inputs[0].gradient(), x)) + self.assertTrue(np.allclose(ret2.numpy(), x * 10)) + a = inputs2[0].gradient() + self.assertTrue(np.allclose(inputs2[0].gradient(), x)) def test_layer(self): with fluid.dygraph.guard(): @@ -291,6 +304,17 @@ class TestImperative(unittest.TestCase): x.backward() dy_grad = l._x_for_debug.gradient() + with fluid.dygraph.guard(): + var_inp2 = fluid.dygraph.base.to_variable(np_inp) + l2 = MyLayer("my_layer") + x2 = l2(var_inp2)[0] + self.assertIsNotNone(x2) + dy_out2 = x2.numpy() + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + x2.backward(backward_strategy) + dy_grad2 = l2._x_for_debug.gradient() + with new_program_scope(): inp = fluid.layers.data( name="inp", shape=[3], append_batch_size=False) @@ -307,6 +331,8 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.allclose(dy_out, static_out)) self.assertTrue(np.allclose(dy_grad, static_grad)) + self.assertTrue(np.allclose(dy_out2, static_out)) + self.assertTrue(np.allclose(dy_grad2, static_grad)) def test_mlp(self): np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) @@ -318,6 +344,16 @@ class TestImperative(unittest.TestCase): out.backward() dy_grad = mlp._fc1._w.gradient() + with fluid.dygraph.guard(): + var_inp2 = fluid.dygraph.base.to_variable(np_inp) + mlp2 = MLP("mlp") + out2 = mlp2(var_inp2) + dy_out2 = out2.numpy() + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + out2.backward(backward_strategy) + dy_grad2 = mlp2._fc1._w.gradient() + with new_program_scope(): inp = fluid.layers.data( name="inp", shape=[2, 2], append_batch_size=False) @@ -335,6 +371,8 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.allclose(dy_out, static_out)) self.assertTrue(np.allclose(dy_grad, static_grad)) + self.assertTrue(np.allclose(dy_out2, static_out)) + self.assertTrue(np.allclose(dy_grad2, static_grad)) params = mlp.parameters(True) self.assertEqual("mlp/MLP_0/FC_0.w_0", params[0].name) @@ -413,6 +451,19 @@ class TestImperative(unittest.TestCase): dy_grad_h2h = simple_rnn._cell._h2h_w.gradient() dy_grad_i2h = simple_rnn._cell._i2h_w.gradient() + with fluid.dygraph.guard(): + var_inp2 = fluid.dygraph.base.to_variable(np_inp) + var_inp2 = fluid.layers.reshape(var_inp2, shape=[1, 4, 3]) + simple_rnn2 = SimpleRNN("simple_rnn") + outs2, pre_hiddens2 = 
simple_rnn2.forward(var_inp2) + dy_out2 = outs2[3].numpy() + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + outs2[3].backward(backward_strategy) + dy_grad_h2o2 = simple_rnn2._cell._h2o_w.gradient() + dy_grad_h2h2 = simple_rnn2._cell._h2h_w.gradient() + dy_grad_i2h2 = simple_rnn2._cell._i2h_w.gradient() + with new_program_scope(): inp = fluid.layers.data( name="inp", shape=[1, 4, 3], append_batch_size=False) @@ -427,10 +478,15 @@ class TestImperative(unittest.TestCase): outs[3].name, param_grads[0][1].name, param_grads[1][1].name, param_grads[2][1].name ]) + self.assertTrue(np.allclose(dy_out, static_out)) self.assertTrue(np.allclose(dy_grad_h2o, static_grad_h2o)) self.assertTrue(np.allclose(dy_grad_h2h, static_grad_h2h)) self.assertTrue(np.allclose(dy_grad_i2h, static_grad_i2h)) + self.assertTrue(np.allclose(dy_out2, static_out)) + self.assertTrue(np.allclose(dy_grad_h2o2, static_grad_h2o)) + self.assertTrue(np.allclose(dy_grad_h2h2, static_grad_h2h)) + self.assertTrue(np.allclose(dy_grad_i2h2, static_grad_i2h)) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py index ca2cffa9c7..daf8cc00d4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py @@ -258,7 +258,35 @@ class TestDygraphDeepCF(unittest.TestCase): dy_loss = loss.numpy() sys.stderr.write('dynamic loss: %s %s\n' % (slice, dy_loss)) + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + + deepcf2 = DeepCF('deepcf', num_users, num_items, matrix) + adam2 = fluid.optimizer.AdamOptimizer(0.01) + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + for e in range(NUM_EPOCHES): + sys.stderr.write('epoch %d\n' % e) + for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): + if slice + BATCH_SIZE >= users_np.shape[0]: + break + prediction2 = deepcf2( + to_variable(users_np[slice:slice + BATCH_SIZE]), + to_variable(items_np[slice:slice + BATCH_SIZE])) + loss2 = fluid.layers.reduce_sum( + fluid.layers.log_loss(prediction2, + to_variable(labels_np[ + slice:slice + BATCH_SIZE]))) + loss2.backward(backward_strategy) + adam2.minimize(loss2) + deepcf2.clear_gradients() + dy_loss2 = loss2.numpy() + sys.stderr.write('dynamic loss: %s %s\n' % + (slice, dy_loss2)) + self.assertEqual(static_loss, dy_loss) + self.assertEqual(static_loss, dy_loss2) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gan.py b/python/paddle/fluid/tests/unittests/test_imperative_gan.py index 5d773ec1c9..7e8cebab44 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gan.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gan.py @@ -170,11 +170,59 @@ class TestDygraphGAN(unittest.TestCase): dy_g_loss = g_loss.numpy() dy_d_loss = d_loss.numpy() + dy_params2 = dict() + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + discriminator2 = Discriminator("d") + generator2 = Generator("g") + sgd2 = SGDOptimizer(learning_rate=1e-3) + + d_real2 = discriminator2(to_variable(np.ones([2, 1], np.float32))) + d_loss_real2 = fluid.layers.reduce_mean( + 
fluid.layers.sigmoid_cross_entropy_with_logits( + x=d_real2, label=to_variable(np.ones([2, 1], np.float32)))) + + d_fake2 = discriminator2( + generator2(to_variable(np.ones([2, 2], np.float32)))) + d_loss_fake2 = fluid.layers.reduce_mean( + fluid.layers.sigmoid_cross_entropy_with_logits( + x=d_fake2, label=to_variable(np.zeros([2, 1], np.float32)))) + + d_loss2 = d_loss_real2 + d_loss_fake2 + d_loss2.backward(backward_strategy) + sgd2.minimize(d_loss2) + discriminator2.clear_gradients() + generator2.clear_gradients() + + d_fake2 = discriminator2( + generator2(to_variable(np.ones([2, 2], np.float32)))) + g_loss2 = fluid.layers.reduce_mean( + fluid.layers.sigmoid_cross_entropy_with_logits( + x=d_fake2, label=to_variable(np.ones([2, 1], np.float32)))) + g_loss2.backward(backward_strategy) + sgd2.minimize(g_loss2) + for p in discriminator2.parameters(): + dy_params2[p.name] = p.numpy() + for p in generator.parameters(): + dy_params2[p.name] = p.numpy() + + dy_g_loss2 = g_loss2.numpy() + dy_d_loss2 = d_loss2.numpy() + self.assertEqual(dy_g_loss, static_g_loss) self.assertEqual(dy_d_loss, static_d_loss) for k, v in six.iteritems(dy_params): self.assertTrue(np.allclose(v, static_params[k])) + self.assertEqual(dy_g_loss2, static_g_loss) + self.assertEqual(dy_d_loss2, static_d_loss) + for k, v in six.iteritems(dy_params2): + self.assertTrue(np.allclose(v, static_params[k])) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gnn.py b/python/paddle/fluid/tests/unittests/test_imperative_gnn.py index 234fcd6040..9ca1b9fd7a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gnn.py @@ -132,9 +132,32 @@ class TestDygraphGNN(unittest.TestCase): loss = fluid.layers.reduce_sum(loss) adam = AdamOptimizer(learning_rate=1e-3) adam.minimize(loss) - self.assertEqual(static_loss, loss.numpy()) - self.assertTrue(np.allclose(static_weight, model.gc.weight.numpy())) - sys.stderr.write('%s %s\n' % (static_loss, loss.numpy())) + + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + + features2 = np.zeros([1, 100, 50], dtype=np.float32) + # Use selected rows when it's supported. + adj2 = np.zeros([1, 100, 100], dtype=np.float32) + labels2 = np.zeros([100, 1], dtype=np.int64) + + model2 = GCN('test_gcn', 50) + logits2 = model2(to_variable(features2), to_variable(adj2)) + logits2 = fluid.layers.reshape(logits2, logits2.shape[1:]) + # In other example, it's nll with log_softmax. However, paddle's + # log_loss only supports binary classification now. 
+ loss2 = fluid.layers.softmax_with_cross_entropy( + logits2, to_variable(labels2)) + loss2 = fluid.layers.reduce_sum(loss2) + adam2 = AdamOptimizer(learning_rate=1e-3) + adam2.minimize(loss2) + + self.assertEqual(static_loss, loss.numpy()) + self.assertTrue(np.allclose(static_weight, model.gc.weight.numpy())) + self.assertEqual(static_loss, loss2.numpy()) + self.assertTrue(np.allclose(static_weight, model2.gc.weight.numpy())) + sys.stderr.write('%s %s\n' % (static_loss, loss.numpy())) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py new file mode 100644 index 0000000000..0f5eb52e22 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py @@ -0,0 +1,149 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import contextlib +import unittest +import numpy as np +import six + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.base import to_variable +from test_imperative_base import new_program_scope +from test_imperative_mnist import MNIST + + +class TestImperativeMnistSortGradient(unittest.TestCase): + def test_mnist_sort_gradient_float32(self): + seed = 90 + epoch_num = 1 + + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + + mnist2 = MNIST("mnist") + sgd2 = SGDOptimizer(learning_rate=1e-3) + train_reader2 = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128, drop_last=True) + + mnist2.train() + dy_param_init_value2 = {} + for epoch in range(epoch_num): + for batch_id, data in enumerate(train_reader2()): + dy_x_data2 = np.array( + [x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data2 = np.array( + [x[1] for x in data]).astype('int64').reshape(128, 1) + + img2 = to_variable(dy_x_data2) + label2 = to_variable(y_data2) + label2.stop_gradient = True + + cost2 = mnist2(img2) + loss2 = fluid.layers.cross_entropy(cost2, label2) + avg_loss2 = fluid.layers.mean(loss2) + + dy_out2 = avg_loss2.numpy() + + if epoch == 0 and batch_id == 0: + for param in mnist2.parameters(): + dy_param_init_value2[param.name] = param.numpy() + + avg_loss2.backward(backward_strategy) + sgd2.minimize(avg_loss2) + mnist2.clear_gradients() + + dy_param_value2 = {} + for param in mnist2.parameters(): + dy_param_value2[param.name] = param.numpy() + if batch_id == 20: + break + + with new_program_scope(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + + exe = fluid.Executor(fluid.CPUPlace( + ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) + + mnist = 
MNIST("mnist") + sgd = SGDOptimizer(learning_rate=1e-3) + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128, drop_last=True) + + img = fluid.layers.data( + name='pixel', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + cost = mnist(img) + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + sgd.minimize(avg_loss) + + # initialize params and fetch them + static_param_init_value = {} + static_param_name_list = [] + for param in mnist.parameters(): + static_param_name_list.append(param.name) + + out = exe.run(fluid.default_startup_program(), + fetch_list=static_param_name_list) + + for i in range(len(static_param_name_list)): + static_param_init_value[static_param_name_list[i]] = out[i] + + for epoch in range(epoch_num): + for batch_id, data in enumerate(train_reader()): + static_x_data = np.array( + [x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape([128, 1]) + + fetch_list = [avg_loss.name] + fetch_list.extend(static_param_name_list) + out = exe.run( + fluid.default_main_program(), + feed={"pixel": static_x_data, + "label": y_data}, + fetch_list=fetch_list) + + static_param_value = {} + static_out = out[0] + for i in range(1, len(out)): + static_param_value[static_param_name_list[i - 1]] = out[ + i] + if batch_id == 20: + break + + self.assertTrue(np.allclose(dy_x_data2.all(), static_x_data.all())) + + for key, value in six.iteritems(static_param_init_value): + self.assertTrue(np.allclose(value, dy_param_init_value2[key])) + + self.assertTrue(np.allclose(static_out, dy_out2)) + + for key, value in six.iteritems(static_param_value): + self.assertTrue(np.allclose(value, dy_param_value2[key], atol=1e-5)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py new file mode 100644 index 0000000000..d3beed7b00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py @@ -0,0 +1,165 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.dygraph.nn import Embedding +import paddle.fluid.framework as framework +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.base import to_variable +from test_imperative_base import new_program_scope +from test_imperative_ptb_rnn import PtbModel +import numpy as np +import six + + +class TestDygraphPtbRnnSortGradient(unittest.TestCase): + def test_ptb_rnn_sort_gradient_cpu_float32(self): + seed = 90 + hidden_size = 10 + vocab_size = 1000 + num_layers = 1 + num_steps = 3 + init_scale = 0.1 + batch_size = 4 + batch_num = 200 + + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + # TODO: marsyang1993 Change seed to + ptb_model = PtbModel( + "ptb_model", + hidden_size=hidden_size, + vocab_size=vocab_size, + num_layers=num_layers, + num_steps=num_steps, + init_scale=init_scale) + + sgd = SGDOptimizer(learning_rate=1e-3) + dy_param_updated = dict() + dy_param_init = dict() + dy_loss = None + last_hidden = None + last_cell = None + + for i in range(batch_num): + x_data = np.arange(12).reshape(4, 3).astype('int64') + y_data = np.arange(1, 13).reshape(4, 3).astype('int64') + x_data = x_data.reshape((-1, num_steps, 1)) + y_data = y_data.reshape((-1, 1)) + init_hidden_data = np.zeros( + (num_layers, batch_size, hidden_size), dtype='float32') + init_cell_data = np.zeros( + (num_layers, batch_size, hidden_size), dtype='float32') + x = to_variable(x_data) + y = to_variable(y_data) + init_hidden = to_variable(init_hidden_data) + init_cell = to_variable(init_cell_data) + dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden, + init_cell) + if i == 0: + for param in ptb_model.parameters(): + dy_param_init[param.name] = param.numpy() + dy_loss.backward(backward_strategy) + sgd.minimize(dy_loss) + ptb_model.clear_gradients() + if i == batch_num - 1: + for param in ptb_model.parameters(): + dy_param_updated[param.name] = param.numpy() + + with new_program_scope(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + ptb_model = PtbModel( + "ptb_model", + hidden_size=hidden_size, + vocab_size=vocab_size, + num_layers=num_layers, + num_steps=num_steps, + init_scale=init_scale) + + exe = fluid.Executor(fluid.CPUPlace( + ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) + sgd = SGDOptimizer(learning_rate=1e-3) + x = fluid.layers.data( + name="x", shape=[-1, num_steps, 1], dtype='int64') + y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32') + init_hidden = fluid.layers.data( + name="init_hidden", shape=[1], dtype='float32') + init_cell = fluid.layers.data( + name="init_cell", shape=[1], dtype='float32') + + static_loss, static_last_hidden, static_last_cell = ptb_model( + x, y, init_hidden, init_cell) + sgd.minimize(static_loss) + static_param_updated = dict() + static_param_init = dict() + static_param_name_list = list() + for param in ptb_model.parameters(): + static_param_name_list.append(param.name) + + out = exe.run(framework.default_startup_program(), + fetch_list=static_param_name_list) + for i in range(len(static_param_name_list)): + static_param_init[static_param_name_list[i]] = out[i] + static_loss_value = None + static_last_cell_value = None + static_last_hidden_value = None + 
for i in range(batch_num): + x_data = np.arange(12).reshape(4, 3).astype('int64') + y_data = np.arange(1, 13).reshape(4, 3).astype('int64') + x_data = x_data.reshape((-1, num_steps, 1)) + y_data = y_data.reshape((-1, 1)) + init_hidden_data = np.zeros( + (num_layers, batch_size, hidden_size), dtype='float32') + init_cell_data = np.zeros( + (num_layers, batch_size, hidden_size), dtype='float32') + fetch_list = [static_loss, static_last_hidden, static_last_cell] + fetch_list.extend(static_param_name_list) + out = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "y": y_data, + "init_hidden": init_hidden_data, + "init_cell": init_cell_data + }, + fetch_list=fetch_list) + static_loss_value = out[0] + static_last_hidden_value = out[1] + static_last_cell_value = out[2] + + if i == batch_num - 1: + for k in range(3, len(out)): + static_param_updated[static_param_name_list[k - + 3]] = out[k] + + self.assertTrue(np.array_equal(static_loss_value, dy_loss.numpy())) + self.assertTrue( + np.array_equal(static_last_cell_value, last_cell.numpy())) + self.assertTrue( + np.array_equal(static_last_hidden_value, last_hidden.numpy())) + for key, value in six.iteritems(static_param_init): + self.assertTrue(np.array_equal(value, dy_param_init[key])) + for key, value in six.iteritems(static_param_updated): + self.assertTrue(np.array_equal(value, dy_param_updated[key])) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py new file mode 100644 index 0000000000..77e6fc2734 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py @@ -0,0 +1,230 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +import six + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.dygraph.base import to_variable +from test_imperative_base import new_program_scope +from test_imperative_resnet import ResNet + +batch_size = 8 +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": batch_size, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + }, + "batch_size": batch_size, + "lr": 0.1, + "total_images": 1281164, +} + + +def optimizer_setting(params): + ls = params["learning_strategy"] + if ls["name"] == "piecewise_decay": + if "total_images" not in params: + total_images = 1281167 + else: + total_images = params["total_images"] + batch_size = ls["batch_size"] + step = int(total_images / batch_size + 1) + + bd = [step * e for e in ls["epochs"]] + base_lr = params["lr"] + lr = [] + lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] + optimizer = fluid.optimizer.SGD(learning_rate=0.01) + # TODO(minqiyang): Add learning rate scheduler support to dygraph mode + # optimizer = fluid.optimizer.Momentum( + # learning_rate=params["lr"], + # learning_rate=fluid.layers.piecewise_decay( + # boundaries=bd, values=lr), + # momentum=0.9, + # regularization=fluid.regularizer.L2Decay(1e-4)) + + return optimizer + + +class TestDygraphResnetSortGradient(unittest.TestCase): + def test_resnet_sort_gradient_float32(self): + seed = 90 + + batch_size = train_parameters["batch_size"] + batch_num = 20 + with fluid.dygraph.guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + resnet = ResNet("resnet") + optimizer = optimizer_setting(train_parameters) + np.random.seed(seed) + import random + random.seed = seed + train_reader = paddle.batch( + paddle.dataset.flowers.train(use_xmap=False), + batch_size=batch_size) + + dy_param_init_value = {} + for param in resnet.parameters(): + dy_param_init_value[param.name] = param.numpy() + + for batch_id, data in enumerate(train_reader()): + if batch_id >= batch_num: + break + + dy_x_data = np.array( + [x[0].reshape(3, 224, 224) for x in data]).astype('float32') + y_data = np.array([x[1] for x in data]).astype('int64').reshape( + batch_size, 1) + + img = to_variable(dy_x_data) + label = to_variable(y_data) + label.stop_gradient = True + + out = resnet(img) + loss = fluid.layers.cross_entropy(input=out, label=label) + avg_loss = fluid.layers.mean(x=loss) + + dy_out = avg_loss.numpy() + + if batch_id == 0: + for param in resnet.parameters(): + if param.name not in dy_param_init_value: + dy_param_init_value[param.name] = param.numpy() + + avg_loss.backward(backward_strategy) + + dy_grad_value = {} + for param in resnet.parameters(): + if param.trainable: + np_array = np.array(param._ivar._grad_ivar().value() + .get_tensor()) + dy_grad_value[param.name + core.grad_var_suffix( + )] = np_array + + optimizer.minimize(avg_loss) + resnet.clear_gradients() + + dy_param_value = {} + for param in resnet.parameters(): + dy_param_value[param.name] = param.numpy() + + with new_program_scope(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + + exe = fluid.Executor(fluid.CPUPlace( + ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) + + resnet = 
ResNet("resnet") + optimizer = optimizer_setting(train_parameters) + + np.random.seed(seed) + import random + random.seed = seed + train_reader = paddle.batch( + paddle.dataset.flowers.train(use_xmap=False), + batch_size=batch_size) + + img = fluid.layers.data( + name='pixel', shape=[3, 224, 224], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = resnet(img) + loss = fluid.layers.cross_entropy(input=out, label=label) + avg_loss = fluid.layers.mean(x=loss) + optimizer.minimize(avg_loss) + + # initialize params and fetch them + static_param_init_value = {} + static_param_name_list = [] + static_grad_name_list = [] + for param in resnet.parameters(): + static_param_name_list.append(param.name) + for param in resnet.parameters(): + if param.trainable: + static_grad_name_list.append(param.name + + core.grad_var_suffix()) + + out = exe.run(fluid.default_startup_program(), + fetch_list=static_param_name_list) + + for i in range(len(static_param_name_list)): + static_param_init_value[static_param_name_list[i]] = out[i] + + for batch_id, data in enumerate(train_reader()): + if batch_id >= batch_num: + break + + static_x_data = np.array( + [x[0].reshape(3, 224, 224) for x in data]).astype('float32') + y_data = np.array([x[1] for x in data]).astype('int64').reshape( + [batch_size, 1]) + + fetch_list = [avg_loss.name] + fetch_list.extend(static_param_name_list) + fetch_list.extend(static_grad_name_list) + out = exe.run(fluid.default_main_program(), + feed={"pixel": static_x_data, + "label": y_data}, + fetch_list=fetch_list) + + static_param_value = {} + static_grad_value = {} + static_out = out[0] + param_start_pos = 1 + grad_start_pos = len(static_param_name_list) + param_start_pos + for i in range(param_start_pos, + len(static_param_name_list) + param_start_pos): + static_param_value[static_param_name_list[ + i - param_start_pos]] = out[i] + for i in range(grad_start_pos, + len(static_grad_name_list) + grad_start_pos): + static_grad_value[static_grad_name_list[ + i - grad_start_pos]] = out[i] + + self.assertTrue(np.allclose(static_out, dy_out)) + + self.assertEqual(len(dy_param_init_value), len(static_param_init_value)) + + for key, value in six.iteritems(static_param_init_value): + self.assertTrue(np.allclose(value, dy_param_init_value[key])) + self.assertTrue(np.isfinite(value.all())) + self.assertFalse(np.isnan(value.any())) + + self.assertEqual(len(dy_grad_value), len(static_grad_value)) + for key, value in six.iteritems(static_grad_value): + self.assertTrue(np.allclose(value, dy_grad_value[key])) + self.assertTrue(np.isfinite(value.all())) + self.assertFalse(np.isnan(value.any())) + + self.assertEqual(len(dy_param_value), len(static_param_value)) + for key, value in six.iteritems(static_param_value): + self.assertTrue(np.allclose(value, dy_param_value[key])) + self.assertTrue(np.isfinite(value.all())) + self.assertFalse(np.isnan(value.any())) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py new file mode 100644 index 0000000000..a2664bf0e7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py @@ -0,0 +1,367 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +from paddle.fluid import Embedding, LayerNorm, FC, Layer +from paddle.fluid.dygraph import to_variable, guard +from test_imperative_base import new_program_scope +from paddle.fluid import core +from test_imperative_transformer import TransFormer, TrainTaskConfig, ModelHyperParams +import numpy as np +import six +np.set_printoptions(suppress=True) + + +def create_data(is_static=False): + if is_static: + return [ + src_word_np, src_pos_np, src_slf_attn_bias_np, trg_word_np, + trg_pos_np, trg_slf_attn_bias_np, trg_src_attn_bias_np, lbl_word_np, + lbl_weight_np + ] + else: + enc_inputs = [ + to_variable( + src_word_np, name='src_word'), to_variable( + src_pos_np, name='src_pos'), to_variable( + src_slf_attn_bias_np, name='src_slf_attn_bias') + ] + dec_inputs = [ + to_variable( + trg_word_np, name='trg_word'), to_variable( + trg_pos_np, name='trg_pos'), to_variable( + trg_slf_attn_bias_np, name='trg_slf_attn_bias'), + to_variable( + trg_src_attn_bias_np, name='trg_src_attn_bias') + ] + label = to_variable(lbl_word_np, name='lbl_word') + weight = to_variable(lbl_weight_np, name='lbl_weight') + return enc_inputs, dec_inputs, label, weight + + +def create_feed_dict_list(data, init=False): + if init: + data_input_names = encoder_data_input_fields + \ + decoder_data_input_fields[:-1] + label_data_input_fields + pos_enc_param_names + else: + data_input_names = encoder_data_input_fields + \ + decoder_data_input_fields[:-1] + label_data_input_fields + feed_dict_list = dict() + for i in range(len(data_input_names)): + feed_dict_list[data_input_names[i]] = data[i] + return feed_dict_list + + +def make_all_inputs(input_fields): + """ + Define the input data layers for the transformer model. + """ + inputs = [] + for input_field in input_fields: + input_var = fluid.layers.data( + name=input_field, + shape=input_descs[input_field][0], + dtype=input_descs[input_field][1], + lod_level=input_descs[input_field][2] + if len(input_descs[input_field]) == 3 else 0, + append_batch_size=False) + inputs.append(input_var) + return inputs + + +# The placeholder for batch_size in compile time. Must be -1 currently to be +# consistent with some ops' infer-shape output in compile time, such as the +# sequence_expand op used in beamsearch decoder. +batch_size = -1 +# The placeholder for squence length in compile time. +seq_len = ModelHyperParams.max_length +# Here list the data shapes and data types of all inputs. +# The shapes here act as placeholder and are set to pass the infer-shape in +# compile time. +input_descs = { + # The actual data shape of src_word is: + # [batch_size, max_src_len_in_batch, 1] + "src_word": [(batch_size, seq_len, 1), "int64", 2], + # The actual data shape of src_pos is: + # [batch_size, max_src_len_in_batch, 1] + "src_pos": [(batch_size, seq_len, 1), "int64"], + # This input is used to remove attention weights on paddings in the + # encoder. 
+ # The actual data shape of src_slf_attn_bias is: + # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] + "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, + seq_len), "float32"], + # The actual data shape of trg_word is: + # [batch_size, max_trg_len_in_batch, 1] + "trg_word": [(batch_size, seq_len, 1), "int64", + 2], # lod_level is only used in fast decoder. + # The actual data shape of trg_pos is: + # [batch_size, max_trg_len_in_batch, 1] + "trg_pos": [(batch_size, seq_len, 1), "int64"], + # This input is used to remove attention weights on paddings and + # subsequent words in the decoder. + # The actual data shape of trg_slf_attn_bias is: + # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] + "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, + seq_len), "float32"], + # This input is used to remove attention weights on paddings of the source + # input in the encoder-decoder attention. + # The actual data shape of trg_src_attn_bias is: + # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] + "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, + seq_len), "float32"], + # This input is used in independent decoder program for inference. + # The actual data shape of enc_output is: + # [batch_size, max_src_len_in_batch, d_model] + "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"], + # The actual data shape of label_word is: + # [batch_size * max_trg_len_in_batch, 1] + "lbl_word": [(batch_size * seq_len, 1), "int64"], + # This input is used to mask out the loss of paddding tokens. + # The actual data shape of label_weight is: + # [batch_size * max_trg_len_in_batch, 1] + "lbl_weight": [(batch_size * seq_len, 1), "float32"], + # This input is used in beam-search decoder. + "init_score": [(batch_size, 1), "float32", 2], + # This input is used in beam-search decoder for the first gather + # (cell states updation) + "init_idx": [(batch_size, ), "int32"], +} + +# Names of word embedding table which might be reused for weight sharing. +word_emb_param_names = ( + "src_word_emb_table", + "trg_word_emb_table", ) +# Names of position encoding table which will be initialized externally. +pos_enc_param_names = ( + "src_pos_enc_table", + "trg_pos_enc_table", ) +# separated inputs for different usages. +encoder_data_input_fields = ( + "src_word", + "src_pos", + "src_slf_attn_bias", ) +decoder_data_input_fields = ( + "trg_word", + "trg_pos", + "trg_slf_attn_bias", + "trg_src_attn_bias", + "enc_output", ) +label_data_input_fields = ( + "lbl_word", + "lbl_weight", ) +# In fast decoder, trg_pos (only containing the current time step) is generated +# by ops and trg_slf_attn_bias is not needed. 
+fast_decoder_data_input_fields = ( + "trg_word", + "init_score", + "init_idx", + "trg_src_attn_bias", ) +# if we use py_reader +use_py_reader = False + +# if we run sync mode +sync = False + +# how many batches we use +batch_num = 5 + +np.random.seed = 90 +src_word_np = np.random.randint( + 1, + ModelHyperParams.src_vocab_size - 1, + size=(TrainTaskConfig.batch_size, seq_len, 1), + dtype='int64') +src_pos_np = np.random.randint( + 1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64') +src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size, + ModelHyperParams.n_head, seq_len, + seq_len).astype('float32') + +trg_word_np = np.random.randint( + 1, + ModelHyperParams.src_vocab_size - 1, + size=(TrainTaskConfig.batch_size, seq_len, 1), + dtype='int64') +trg_pos_np = np.random.randint( + 1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64') +trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size, + ModelHyperParams.n_head, seq_len, + seq_len).astype('float32') +trg_src_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size, + ModelHyperParams.n_head, seq_len, + seq_len).astype('float32') + +lbl_word_np = np.random.randint( + 1, + ModelHyperParams.src_vocab_size - 1, + size=(TrainTaskConfig.batch_size * seq_len, 1), + dtype='int64') +lbl_weight_np = np.random.randn(TrainTaskConfig.batch_size * seq_len, + 1).astype('float32') + + +class TestDygraphTransformerSortGradient(unittest.TestCase): + def test_transformer_sort_gradient_float32(self): + seed = 90 + + with guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + transformer = TransFormer( + 'transformer', + ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size, + ModelHyperParams.max_length + 1, + ModelHyperParams.n_layer, + ModelHyperParams.n_head, + ModelHyperParams.d_key, + ModelHyperParams.d_value, + ModelHyperParams.d_model, + ModelHyperParams.d_inner_hid, + ModelHyperParams.prepostprocess_dropout, + ModelHyperParams.attention_dropout, + ModelHyperParams.relu_dropout, + ModelHyperParams.preprocess_cmd, + ModelHyperParams.postprocess_cmd, + ModelHyperParams.weight_sharing, + TrainTaskConfig.label_smooth_eps, + use_py_reader=use_py_reader, + is_test=False) + if sync: + lr_decay = fluid.layers.learning_rate_scheduler.noam_decay( + ModelHyperParams.d_model, TrainTaskConfig.warmup_steps) + with fluid.default_main_program()._lr_schedule_guard(): + learning_rate = lr_decay * TrainTaskConfig.learning_rate + optimizer = fluid.optimizer.Adam( + learning_rate=learning_rate, + beta1=TrainTaskConfig.beta1, + beta2=TrainTaskConfig.beta2, + epsilon=TrainTaskConfig.eps) + else: + optimizer = fluid.optimizer.SGD(learning_rate=0.003) + dy_param_init = dict() + dy_param_updated = dict() + for i in range(batch_num): + enc_inputs, dec_inputs, label, weights = create_data() + dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer( + enc_inputs, dec_inputs, label, weights) + + if i == 0: + for param in transformer.parameters(): + dy_param_init[param.name] = param.numpy() + + dy_avg_cost.backward(backward_strategy) + optimizer.minimize(dy_avg_cost) + transformer.clear_gradients() + + if i == batch_num - 1: + for param in transformer.parameters(): + dy_param_updated[param.name] = param.numpy() + + with new_program_scope(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = 
seed + transformer = TransFormer( + 'transformer', + ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size, + ModelHyperParams.max_length + 1, + ModelHyperParams.n_layer, + ModelHyperParams.n_head, + ModelHyperParams.d_key, + ModelHyperParams.d_value, + ModelHyperParams.d_model, + ModelHyperParams.d_inner_hid, + ModelHyperParams.prepostprocess_dropout, + ModelHyperParams.attention_dropout, + ModelHyperParams.relu_dropout, + ModelHyperParams.preprocess_cmd, + ModelHyperParams.postprocess_cmd, + ModelHyperParams.weight_sharing, + TrainTaskConfig.label_smooth_eps, + use_py_reader=use_py_reader, + is_test=False) + exe = fluid.Executor(fluid.CPUPlace( + ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) + optimizer = fluid.optimizer.SGD(learning_rate=0.003) + + data_input_names = encoder_data_input_fields + decoder_data_input_fields[: + -1] + label_data_input_fields + all_inputs = make_all_inputs(data_input_names) + enc_inputs_len = len(encoder_data_input_fields) + dec_inputs_len = len(decoder_data_input_fields[:-1]) + enc_inputs = all_inputs[0:enc_inputs_len] + dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + + dec_inputs_len] + label = all_inputs[-2] + weights = all_inputs[-1] + static_param_updated = dict() + static_param_init = dict() + static_param_name_list = list() + static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer( + enc_inputs, dec_inputs, label, weights) + optimizer.minimize(static_avg_cost) + for param in transformer.parameters(): + static_param_name_list.append(param.name) + out = exe.run(fluid.default_startup_program(), + fetch_list=static_param_name_list) + for i in range(len(static_param_name_list)): + static_param_init[static_param_name_list[i]] = out[i] + static_sum_cost_value = None + static_avg_cost_value = None + static_predict_value = None + static_token_num_value = None + for i in range(batch_num): + feed_dict = create_feed_dict_list(create_data(True)) + fetch_list = [ + static_sum_cost, static_avg_cost, static_predict, + static_token_num + ] + + fetch_list.extend(static_param_name_list) + out = exe.run(fluid.default_main_program(), + feed=feed_dict, + fetch_list=fetch_list) + static_sum_cost_value = out[0] + static_avg_cost_value = out[1] + static_predict_value = out[2] + static_token_num_value = out[3] + if i == batch_num - 1: + for k in range(4, len(out)): + static_param_updated[static_param_name_list[k - + 4]] = out[k] + + self.assertTrue( + np.array_equal(static_avg_cost_value, dy_avg_cost.numpy())) + self.assertTrue( + np.array_equal(static_sum_cost_value, dy_sum_cost.numpy())) + self.assertTrue( + np.array_equal(static_predict_value, dy_predict.numpy())) + self.assertTrue( + np.array_equal(static_token_num_value, dy_token_num.numpy())) + + for key, value in six.iteritems(static_param_init): + self.assertTrue(np.array_equal(value, dy_param_init[key])) + for key, value in six.iteritems(static_param_updated): + self.assertTrue(np.array_equal(value, dy_param_updated[key])) + + +if __name__ == '__main__': + unittest.main() -- GitLab
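
For a quick picture of how the new strategy is meant to be driven from Python, the sketch below strings together the pieces this patch adds (`fluid.dygraph.BackwardStrategy`, its `sort_sum_gradient` flag, and `Variable.backward(backward_strategy)`), following the pattern of the new case in test_imperative_basic.py. It is illustrative only, not part of the patch: the toy `sums`/`reduce_sum` network is borrowed from that test, and it assumes a Fluid build of this era with dygraph available.

    import numpy as np
    import paddle.fluid as fluid

    # Accumulate the gradient of variables consumed by several ops, once with
    # the default (sum-as-created) behaviour and once with the new sorted-sum
    # strategy. Both runs should produce the same gradient values; the sorted
    # strategy only fixes the order in which partial gradients are summed.
    x = np.ones([2, 2], dtype='float32')

    with fluid.dygraph.guard():
        inputs = [fluid.dygraph.base.to_variable(x) for _ in range(10)]
        loss = fluid.layers.reduce_sum(fluid.layers.sums(inputs))
        # Default behaviour: each partial gradient is added as soon as the
        # producing grad op runs.
        loss.backward()
        grad_default = inputs[0].gradient()

    with fluid.dygraph.guard():
        inputs = [fluid.dygraph.base.to_variable(x) for _ in range(10)]
        loss = fluid.layers.reduce_sum(fluid.layers.sums(inputs))
        # New behaviour: collect all partial gradients first, then sum them
        # in a fixed (trace-id sorted) order.
        strategy = fluid.dygraph.BackwardStrategy()
        strategy.sort_sum_gradient = True
        loss.backward(strategy)
        grad_sorted = inputs[0].gradient()

    assert np.allclose(grad_default, x)
    assert np.allclose(grad_default, grad_sorted)

On the C++ side this corresponds to the path added in layer.cc: when sorted_sum_gradient_ is set, partial gradients for a variable with multiple consumers are buffered in BackwardSumMap, the GradientRef count computed in ComputeDepCounts is decremented per contribution, and only when the last contribution arrives does AddGradBySort sort the buffered pairs by descending trace id and apply AddTo, so the accumulation order no longer depends on graph traversal order.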