diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index 373d292b443b7651b785a52a6986b0a0be58ad61..1fdb64fd0db1ae153d645e637d8bc0a563a3c95c 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -1,3 +1,3 @@
-cc_library(layer SRCS layer.cc DEPS proto_desc operator)
-cc_library(tracer SRCS tracer.cc DEPS proto_desc)
+cc_library(layer SRCS layer.cc DEPS proto_desc operator device_context blas)
+cc_library(tracer SRCS tracer.cc DEPS proto_desc device_context)
 cc_library(engine SRCS engine.cc)
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 7594670cd2608802bdf41682ef5724a7a965d754..ffe276abb25d2b69ffbfdb9a0b3a650d5be556d5 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/imperative/layer.h"
+
 #include <deque>
 #include <limits>
 #include <map>
@@ -22,6 +23,9 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/string/printf.h"
 
 namespace paddle {
@@ -31,22 +35,68 @@ std::map<int, py::object> py_funcs_;
 
 using framework::Variable;
 
-void AddTo(Variable* src, Variable* dst) {
-  framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
-  framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
+namespace detail {
+
+template <typename T>
+class TensorAddToFunctor : public boost::static_visitor<> {
+ public:
+  TensorAddToFunctor(int64_t numel, const T* x, T* y)
+      : numel_(numel), x_(x), y_(y) {}
+
+  void operator()(const platform::CPUPlace& place) {
+    platform::CPUDeviceContext* ctx =
+        dynamic_cast<platform::CPUDeviceContext*>(
+            platform::DeviceContextPool::Instance().Get(place));
+    auto blas =
+        operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
+    blas.AXPY(numel_, 1., x_, y_);
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  void operator()(const platform::CUDAPlace& place) {
+    platform::CUDADeviceContext* ctx =
+        dynamic_cast<platform::CUDADeviceContext*>(
+            platform::DeviceContextPool::Instance().Get(place));
+    auto blas =
+        operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
+    blas.AXPY(numel_, 1., x_, y_);
+  }
+#else
+  void operator()(const platform::CUDAPlace& place) {
+    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
+  }
+#endif
+
+  // there is NO blas in CUDAPinnedPlace
+  void operator()(const platform::CUDAPinnedPlace& place) {
+    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
+  }
+
+ private:
+  int64_t numel_;
+  const T* x_;
+  T* y_;
+};
+
+}  // namespace detail
+
+void AddGradTo(Variable* src, Variable* dst, platform::Place place) {
+  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
+  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
+
   // FIXME(minqiyang): loss_grad op will pass a zero grad of label
   // ugly fix for it
   if (src_tensor->numel() == 0) {
     return;
   }
+
+  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
src_numel %lld", dst_tensor->numel(), src_tensor->numel()); - float* dst_data = dst_tensor->mutable_data(platform::CPUPlace()); - const float* src_data = src_tensor->data(); - for (int64_t i = 0; i < src_tensor->numel(); ++i) { - dst_data[i] += src_data[i]; - } + + detail::TensorAddToFunctor func( + src_tensor->numel(), src_tensor->data(), + dst_tensor->mutable_data(place)); + boost::apply_visitor(func, place); } class Autograd { @@ -158,7 +208,7 @@ std::map> OpBase::ApplyGrad() { PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel"); framework::Scope scope; - platform::CPUPlace place; + platform::Place place = expected_place_; PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place); p.op.RuntimeInferShape(scope, place, ctx); p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx)); @@ -172,7 +222,7 @@ std::map> OpBase::ApplyGrad() { for (size_t i = 0; i < outputs.size(); ++i) { framework::Variable* grad = outputs[i]; framework::Variable* orig_grad = origin_outputs[i]; - AddTo(grad, orig_grad); + AddGradTo(grad, orig_grad, expected_place_); delete grad; } } @@ -184,8 +234,10 @@ void VarBase::RunBackward() { VLOG(3) << "start backward"; auto grads_t = grads_->var_->GetMutable(); - float* data = grads_t->mutable_data(platform::CPUPlace()); - std::fill(data, data + grads_t->numel(), 1.0); + operators::math::set_constant( + *(platform::DeviceContextPool::Instance().Get( + var_->GetMutable()->place())), + grads_t, 1.0); PADDLE_ENFORCE( grads_ == diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index daf56a521085b63926194b958094a7d170873830..5a1ad5540806b29ad1cb789d9bdcb45fbc42c134 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -26,12 +26,15 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { namespace imperative { +class VarBase; + namespace py = ::pybind11; class PreparedOp { @@ -81,6 +84,8 @@ class PreparedOp { return PreparedOp(op, ctx, kernel_iter->second, dev_ctx); } + inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx; } + const framework::OperatorBase& op; const framework::RuntimeContext& ctx; framework::OperatorWithKernel::OpKernelFunc func; @@ -159,7 +164,8 @@ class OpBase { : op_desc_(nullptr), forward_id_(-1), grad_op_desc_(nullptr), - backward_id_(-1) {} + backward_id_(-1), + expected_place_(platform::CPUPlace()) {} virtual ~OpBase() { if (grad_op_desc_) delete grad_op_desc_; @@ -176,6 +182,8 @@ class OpBase { framework::OpDesc* grad_op_desc_; int backward_id_; + platform::Place expected_place_; + VarBasePtrMap input_vars_; VarBasePtrMap output_vars_; OpBasePtrMap pre_ops_; diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index a01225ccee4a82f77ec2a23df75d1cf7b719bdb7..0c7e69cc0be40ba89ee0aa4a45e91ee7e7008462 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -14,6 +14,10 @@ #include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + namespace paddle { namespace imperative { @@ -31,16 +35,38 @@ void CreateGradOp(const framework::OpDesc& op_desc, *grad_op_desc = grad_op_descs[0].release(); } -void InitVar(framework::Variable* var, framework::Variable* grad_var) { 
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index a01225ccee4a82f77ec2a23df75d1cf7b719bdb7..0c7e69cc0be40ba89ee0aa4a45e91ee7e7008462 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -14,6 +14,10 @@
 
 #include "paddle/fluid/imperative/tracer.h"
 
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+
 namespace paddle {
 namespace imperative {
 
@@ -31,16 +35,38 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   *grad_op_desc = grad_op_descs[0].release();
 }
 
-void InitVar(framework::Variable* var, framework::Variable* grad_var) {
+void InitVar(framework::Variable* var, framework::Variable* grad_var,
+             platform::DeviceContext* dev_ctx) {
+  PADDLE_ENFORCE_NOT_NULL(dev_ctx,
+                          "Could not get valid device from forward op");
   auto& var_t = var->Get<framework::LoDTensor>();
-  float* data =
-      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
-          var_t.dims(), platform::CPUPlace());
-  std::fill(data, data + var_t.numel(), 0.0);
+  grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
+      var_t.dims(), dev_ctx->GetPlace());
+  operators::math::set_constant(
+      *dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), .0f);
+}
+
+platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
+  platform::Place result = place;
+  for (auto it : inputs) {
+    for (VarBase* var : it.second) {
+      platform::Place tmp_place =
+          var->var_->Get<framework::LoDTensor>().place();
+      if (!platform::is_same_place(tmp_place, result)) {
+        PADDLE_THROW(
+            "Input variable should keep in the same place: %s, but get place: "
+            "%s of input %s instead",
+            result, tmp_place, it.first);
+      }
+    }
+  }
+
+  return result;
 }
 
 void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                    const VarBasePtrMap& outputs, framework::BlockDesc* block,
+                   const platform::Place expected_place,
                    const bool stop_gradient) {
   std::map<std::string, VarBase*> vars;
 
@@ -108,10 +134,12 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
   framework::Scope scope;
-  platform::CPUPlace place;
-  PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
-  p.op.RuntimeInferShape(scope, place, ctx);
-  p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+  op->expected_place_ = GetExpectedPlace(expected_place, inputs);
+  PreparedOp prepared_op =
+      PreparedOp::Prepare(ctx, *op_kernel, op->expected_place_);
+  prepared_op.op.RuntimeInferShape(scope, op->expected_place_, ctx);
+  prepared_op.func(framework::ExecutionContext(
+      prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
   if (!stop_gradient) {
     framework::OpDesc* grad_op_desc;
@@ -134,7 +162,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         } else {
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
-            InitVar(var->var_, var->grads_->var_);
+            InitVar(var->var_, var->grads_->var_,
+                    prepared_op.GetDeviceContext());
           }
           // Douts.
           grad_in_vars.push_back(var->grads_->var_);
@@ -147,10 +176,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       for (const std::string& grad_outvar : it.second) {
         block->FindRecursiveOrCreateVar(grad_outvar);
         auto var_it = grad_to_var->find(grad_outvar);
-        PADDLE_ENFORCE(var_it != grad_to_var->end());
+        PADDLE_ENFORCE(var_it != grad_to_var->end(),
+                       "Could not find the grad op output var; should this "
+                       "operator %s's stop_gradient be True?",
+                       op_desc->Type());
         VarBase* var = vars[var_it->second];
         if (!var->grads_->var_->IsInitialized()) {
-          InitVar(var->var_, var->grads_->var_);
+          InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
         }
         grad_out_vars.push_back(var->grads_->var_);
       }
@@ -193,16 +225,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     for (VarBase* out : outputs) {
       grad_input_vars.push_back(out->var_);
     }
+
+    platform::CPUPlace place;
     for (VarBase* out : outputs) {
       grad_input_vars.push_back(out->grads_->var_);
       if (!grad_input_vars.back()->IsInitialized()) {
-        InitVar(out->var_, grad_input_vars.back());
+        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
+        InitVar(out->var_, grad_input_vars.back(),
+                platform::DeviceContextPool::Instance().Get(place));
       }
     }
+
     for (const VarBase* inp : inputs) {
       grad_output_vars.push_back(inp->grads_->var_);
       if (!grad_output_vars.back()->IsInitialized()) {
-        InitVar(inp->var_, grad_output_vars.back());
+        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
+        InitVar(inp->var_, grad_output_vars.back(),
+                platform::DeviceContextPool::Instance().Get(place));
       }
     }
   }
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index f225d8abe6c0635d2bdd8dba0b12c7fc3a4110db..690838215581b09ff35a0ea13f30655b77e6e187 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/imperative/engine.h"
 #include "paddle/fluid/imperative/layer.h"
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace imperative {
@@ -34,21 +35,25 @@ void CreateGradOp(const framework::OpDesc& op_desc,
 
 void InitVar(framework::Variable* var, framework::Variable* grad_var);
 
+platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
+
 class Tracer {
  public:
   explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
 
   virtual ~Tracer() {}
 
-  void Trace(OpBase* op,
-             const std::map<std::string, std::vector<VarBase*>>& inputs,
-             const std::map<std::string, std::vector<VarBase*>>& outputs,
-             framework::BlockDesc* block, const bool stop_gradient = false);
+  void Trace(OpBase* op, const VarBasePtrMap& inputs,
+             const VarBasePtrMap& outputs, framework::BlockDesc* block,
+             const platform::Place expected_place,
+             const bool stop_gradient = false);
 
   std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
                                 bool stop_gradient = false);
 
  private:
+  platform::Place GetPlace(const VarBasePtrMap& inputs);
+
   framework::BlockDesc* root_block_;
 };
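Note on GetExpectedPlace above: it validates rather than migrates; every input of a traced op must already live on the tracer's expected place, and a mismatch aborts the trace instead of triggering a silent cross-device copy. A hedged sketch of that contract from the Python side (import paths as introduced elsewhere in this patch):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.imperative.base import to_variable

    with fluid.imperative.guard(device=0):        # expected place: CUDAPlace(0)
        x = to_variable(np.ones([2, 2], np.float32))  # created on that place
        y = fluid.layers.relu(x)                  # OK: input place matches
        # an input still sitting on CPUPlace would abort the trace with
        # "Input variable should keep in the same place: ..."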
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 8f80a2d7822f1dc16cee2514a991b7341f5d1cfd..2493fb71c019f9923012afa4a46cb3e95479f860 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -30,8 +30,9 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   auto it = device_contexts_.find(place);
   if (it == device_contexts_.end()) {
     PADDLE_THROW(
-        "'Place' is not supported, Please re-compile with WITH_GPU "
-        "option");
+        "Place %s is not supported. Please re-compile with WITH_GPU "
+        "option",
+        place);
   }
   return it->second.get().get();
 }
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index dbc7843caa0c0a39a32cda6050fa99a3ab4c3e22..31c3bfa43ffec22059a602e9ff09a33188d72c91 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -15,18 +15,38 @@ limitations under the License. */
 #include "paddle/fluid/pybind/imperative.h"
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/imperative/tracer.h"
+#include "paddle/fluid/imperative/type_defs.h"
 
 namespace paddle {
 namespace pybind {
 
 // Bind Methods
-void BindTracer(pybind11::module *m) {
+void BindTracer(pybind11::module* m) {
   pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
       .def("__init__",
-           [](imperative::Tracer &self, framework::BlockDesc *root_block) {
+           [](imperative::Tracer& self, framework::BlockDesc* root_block) {
              new (&self) imperative::Tracer(root_block);
            })
-      .def("trace", &imperative::Tracer::Trace)
+      .def("trace",
+           [](imperative::Tracer& self, imperative::OpBase* op,
+              const imperative::VarBasePtrMap& inputs,
+              const imperative::VarBasePtrMap& outputs,
+              framework::BlockDesc* block,
+              const platform::CPUPlace expected_place,
+              const bool stop_gradient = false) {
+             self.Trace(op, inputs, outputs, block, expected_place,
+                        stop_gradient);
+           })
+      .def("trace",
+           [](imperative::Tracer& self, imperative::OpBase* op,
+              const imperative::VarBasePtrMap& inputs,
+              const imperative::VarBasePtrMap& outputs,
+              framework::BlockDesc* block,
+              const platform::CUDAPlace expected_place,
+              const bool stop_gradient = false) {
+             self.Trace(op, inputs, outputs, block, expected_place,
+                        stop_gradient);
+           })
       .def("py_trace", &imperative::Tracer::PyTrace,
            pybind11::return_value_policy::take_ownership);
 }
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 8d061f41f09a88d06a6b0018d95793e8cadbcdf3..012ceafe1e5b759fcbd44a9256c73c7b7e5caa63 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -66,6 +66,7 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
 CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
 
 _imperative_tracer_ = None
+_current_expected_place_ = None
 
 
 def _in_imperative_mode():
@@ -76,6 +77,10 @@ def _imperative_tracer():
     return _imperative_tracer_
 
 
+def _current_expected_place():
+    return _current_expected_place_
+
+
 class NameScope(object):
     def __init__(self, name="", parent=None):
         self._children = dict()
@@ -1299,7 +1304,7 @@ class Block(object):
     def _trace_op(self, op, stop_gradient=False):
         if _in_imperative_mode():
             _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
-                                       stop_gradient)
+                                       _current_expected_place_, stop_gradient)
 
     def _insert_op(self, index, *args, **kwargs):
         """
@@ -2312,9 +2317,16 @@ def _get_var(name, program=None):
 
 
 @contextlib.contextmanager
-def _imperative_guard(tracer):
+def _imperative_guard(tracer, place):
     global _imperative_tracer_
     tmp_trace = _imperative_tracer_
     _imperative_tracer_ = tracer
+
+    global _current_expected_place_
+    tmp_place = _current_expected_place_
+    _current_expected_place_ = place
+
     yield
+
     _imperative_tracer_ = tmp_trace
+    _current_expected_place_ = tmp_place
diff --git a/python/paddle/fluid/imperative/base.py b/python/paddle/fluid/imperative/base.py
index 5d3ebb25a935cea6ec376e6bc044281dcba37337..83789dbe608147fc4e219f13bc4a565fc4cacdf3 100644
--- a/python/paddle/fluid/imperative/base.py
+++ b/python/paddle/fluid/imperative/base.py
@@ -25,17 +25,28 @@ def enabled():
 
 
 @contextlib.contextmanager
-def guard():
+def guard(device=0):
     train = framework.Program()
     startup = framework.Program()
     tracer = core.Tracer(train.current_block().desc)
+
+    if device is None:
+        place = core.CPUPlace()
+    else:
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(device)
+        else:
+            place = core.CPUPlace()
+
     with framework.program_guard(train, startup):
         with framework.unique_name.guard():
-            with framework._imperative_guard(tracer):
+            with framework._imperative_guard(tracer, place):
                 yield
 
 
 def to_variable(value, block=None):
+    assert enabled(), "to_variable can only be called in imperative mode"
+
     if isinstance(value, np.ndarray):
         if not block:
             block = framework.default_main_program().current_block()
@@ -47,9 +58,7 @@ def to_variable(value, block=None):
             dtype=value.dtype)
         var = py_var._ivar.value()
         tensor = var.get_tensor()
-        tensor.set(value, core.CPUPlace())
+        tensor.set(value, framework._current_expected_place())
         return py_var
     elif isinstance(value, framework.Variable):
         return value
-    else:
-        raise ValueError("Unsupported type %s" % type(value))
diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py
index 72d6c20bc644b57c5d1b8188a3f2381d9c6140b9..6528de9a95fd3df3bc469764ca3b70bf48356df5 100644
--- a/python/paddle/fluid/imperative/nn.py
+++ b/python/paddle/fluid/imperative/nn.py
@@ -252,15 +252,15 @@ class FC(layers.Layer):
                 "y_num_col_dims": 1
             })
 
-        out = self._helper.create_variable_for_type_inference(self._dtype)
+        pre_bias = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
             type="sum",
             inputs={"X": [tmp]},
-            outputs={"Out": out},
+            outputs={"Out": pre_bias},
             attrs={"use_mkldnn": False})
 
         pre_activation = self._helper.append_bias_op(
-            pre_bias, dim_start=num_flatten_dims)
+            pre_bias, dim_start=self._num_flatten_dims)
         return self._helper.append_activation(pre_activation)
 
 
@@ -355,11 +355,11 @@ class BatchNorm(layers.Layer):
         variance_out = self._variance
 
         saved_mean = self._helper.create_variable_for_type_inference(
-            dtype=dtype, stop_gradient=True)
+            dtype=self._dtype, stop_gradient=True)
         saved_variance = self._helper.create_variable_for_type_inference(
-            dtype=dtype, stop_gradient=True)
+            dtype=self._dtype, stop_gradient=True)
         batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
-            dtype)
+            self._dtype)
 
         self._helper.append_op(
             type="batch_norm",
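The reworked guard above keeps GPU-by-default behavior while adding an explicit opt-out. A usage sketch of the three device cases as implemented in base.py:

    import paddle.fluid as fluid

    with fluid.imperative.guard():              # CUDAPlace(0) on a CUDA build,
        pass                                    # CPUPlace otherwise
    with fluid.imperative.guard(device=1):      # CUDAPlace(1) on a CUDA build
        pass
    with fluid.imperative.guard(device=None):   # always CPUPlace
        pass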
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index dde05189722fef77e03a1c2d8f3cbae44a3e8245..617704a53138bd081a2ebe318de0c89e8db4aa96 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -321,7 +321,7 @@ def append_LARS(params_grads, learning_rate, weight_decay):
         The decayed learning rate
     Examples:
         .. code-block:: python
-
+
             learning_rate *= local_gw_ratio * sqrt(sumsq(param)) / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param)))
     """
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 235a1556e749b9925b4574ba33669021af49f193..f624dad376547e67627ca1bac68ccdde97f0de53 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -5810,7 +5810,8 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
             type='increment',
             inputs={'X': [counter]},
             outputs={'Out': [counter]},
-            attrs={'step': float(step)})
+            attrs={'step': float(step)},
+            stop_gradient=True)
         counter.stop_gradient = True
 
     return counter
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index ce9f508c9f10981d62b7f8417080f0f3b8d9b8a7..2153ca254f0e286a77160a2d53473e1bc76109d5 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -382,7 +382,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
             'dtype': out.dtype,
             'value': float(value),
             'force_cpu': force_cpu or force_init_on_cpu()
-        })
+        },
+        stop_gradient=True)
     out.stop_gradient = True
     return out
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index f01a0eda9a711abb3265fe5bb86ecb702a6ac6aa..449eaa0970cc7f0e1d5291f21c084f027e9a13ed 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -301,10 +301,10 @@ class Optimizer(object):
             no_grad_set (set|None): set of Variables should be ignored.
             callbacks (list|None): list of callables to run when appending backward
                 operator for one parameter.
-
+
         Return:
             list: list of (param, grad) pair, grad is the output of backward.
-
+
         Examples:
             See examples in `apply_gradients`.
         """
@@ -322,10 +322,10 @@ class Optimizer(object):
 
         Args:
             params_grads (list): list of (param, grad) pair to do optimization.
-
+
         Returns:
             list: A list of operators appended to the current program.
-
+
         Examples:
             .. code-block:: python
 
@@ -364,7 +364,7 @@ class Optimizer(object):
 
         This method combines interface `backward()` and
         `apply_gradients()` into one.
-
+
         Args:
             loss (Variable): loss variable to run optimizations.
             startup_program (Program): startup_program for initializing parameters
@@ -381,18 +381,19 @@ class Optimizer(object):
         optimize_ops = []
         if imperative_base.enabled():
             if parameter_list is not None:
-                params_grads = parameter_list
+                parameters = parameter_list
             else:
                 parameters = program.global_block().all_parameters()
-                params_grads = []
-                for param in parameters:
-                    # create gradient variable
-                    grad_var = Variable(
-                        block=loss.block,
-                        name=param._ivar._grad_name(),
-                        stop_gradient=True,
-                        ivar=param._ivar._grad_ivar())
-                    params_grads.append((param, grad_var))
+
+            params_grads = []
+            for param in parameters:
+                # create gradient variable
+                grad_var = Variable(
+                    block=loss.block,
+                    name=param._ivar._grad_name(),
+                    stop_gradient=True,
+                    ivar=param._ivar._grad_ivar())
+                params_grads.append((param, grad_var))
+
             with program_guard(program, startup_program):
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index ec8b19c7ba07a9e57a32277ff3fc34b0ea25a819..6360951503e6f47b5d21367b436227a70c6cddd3 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -107,7 +107,6 @@ if(WITH_DISTRIBUTE)
 endif()
 py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
 py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
-set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 150)
 py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL)
 if(NOT APPLE)
     py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py
index 86baff3c589d7b8a14938886b3e2104b0beb1cc9..e9aaddb00f841bb37bee605099847a66c39e8d33 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative.py
@@ -82,7 +82,7 @@ class MLP(fluid.imperative.Layer):
 
 class TestImperative(unittest.TestCase):
     def test_layer(self):
-        with fluid.imperative.guard():
+        with fluid.imperative.guard(device=None):
             cl = core.Layer()
             cl.forward([])
             l = fluid.imperative.Layer()
@@ -90,7 +90,7 @@ class TestImperative(unittest.TestCase):
 
     def test_pylayer_func_id(self):
-        with fluid.imperative.guard():
+        with fluid.imperative.guard(device=None):
 
             class PyLayer1(fluid.imperative.PyLayer):
                 def __init__(self):
@@ -130,7 +130,7 @@ class TestImperative(unittest.TestCase):
 
     def test_pylayer(self):
         np_inp = np.ones([2, 2], np.float32)
-        with fluid.imperative.guard():
+        with fluid.imperative.guard(device=None):
            my_py_layer = MyPyLayer()
            var_inp = fluid.imperative.base.to_variable(np_inp)
            outs = my_py_layer(var_inp)
@@ -158,7 +158,7 @@ class TestImperative(unittest.TestCase):
 
     def test_layer_in_out(self):
         np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
-        with fluid.imperative.guard():
+        with fluid.imperative.guard(device=None):
             var_inp = fluid.imperative.base.to_variable(np_inp)
             l = MyLayer()
             x = l(var_inp)[0]
@@ -185,7 +185,7 @@ class TestImperative(unittest.TestCase):
 
     def test_mlp(self):
         np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
-        with fluid.imperative.guard():
+        with fluid.imperative.guard(device=None):
             var_inp = fluid.imperative.base.to_variable(np_inp)
             mlp = MLP()
             out = mlp(var_inp)
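With the de-indented loop in minimize() above, (param, grad) pairs are now built for both branches; previously an explicit parameter_list was passed through as params_grads unconverted, so _create_optimization_pass received bare parameters. A hedged end-to-end sketch (MLP and np_inp as in test_imperative.py above; the _backward() call is assumed from the imperative API of this era and is not part of this patch):

    import paddle.fluid as fluid
    from paddle.fluid.imperative.base import to_variable

    with fluid.imperative.guard(device=None):
        mlp = MLP()
        out = mlp(to_variable(np_inp))
        avg_loss = fluid.layers.mean(x=out)
        avg_loss._backward()                    # fills each param._ivar grad
        sgd = fluid.optimizer.SGD(learning_rate=0.1)
        sgd.minimize(avg_loss)                  # consumes the (param, grad) pairs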
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
index 63eeae4b712c2064309b664b91d5f0347b67817d..34d1654c280a6c801acb046a25c3fe275cc9c79a 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
@@ -101,7 +101,7 @@ class TestImperativeMnist(unittest.TestCase):
     def test_mnist_cpu_float32(self):
         seed = 90
 
-        with fluid.imperative.guard():
+        with fluid.imperative.guard(device=None):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
index 4bf80afd49b5221ab530a46ba997cc08288332f7..594b7519855356df05eb2ef7b59e30ccdabe4ab9 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
@@ -34,7 +34,10 @@ train_parameters = {
         "batch_size": 256,
         "epochs": [30, 60, 90],
         "steps": [0.1, 0.01, 0.001, 0.0001]
-    }
+    },
+    "batch_size": 256,
+    "lr": 0.1,
+    "total_images": 1281164,
 }
 
 
@@ -52,24 +55,33 @@ def optimizer_setting(params):
     base_lr = params["lr"]
     lr = []
     lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
-    optimizer = fluid.optimizer.Momentum(
-        learning_rate=fluid.layers.piecewise_decay(
-            boundaries=bd, values=lr),
-        momentum=0.9,
-        regularization=fluid.regularizer.L2Decay(1e-4))
+    optimizer = fluid.optimizer.SGD(learning_rate=params["lr"])
+    # optimizer = fluid.optimizer.Momentum(
+    #     learning_rate=params["lr"],
+    #     learning_rate=fluid.layers.piecewise_decay(
+    #         boundaries=bd, values=lr),
+    #     momentum=0.9,
+    #     regularization=fluid.regularizer.L2Decay(1e-4))
 
     return optimizer
 
 
 class ConvBNLayer(fluid.imperative.Layer):
-    def __init__(self, num_filters, filter_size, stride=1, groups=1, act=None):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 filter_size,
+                 stride=1,
+                 groups=1,
+                 act=None):
         super(ConvBNLayer, self).__init__()
 
         self._conv = Conv2D(
-            3,
-            num_filters,
-            filter_size,
-            stride, (filter_size - 1) // 2,
+            num_channels=num_channels,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
             groups=groups,
             act=None,
             bias_attr=None)
@@ -84,36 +96,54 @@ class ConvBNLayer(fluid.imperative.Layer):
 
 
 class BottleneckBlock(fluid.imperative.Layer):
-    def __init__(self, num_filters, stride, shortcut=False):
+    def __init__(self, num_channels, num_filters, stride, shortcut=True):
         super(BottleneckBlock, self).__init__()
 
         self.conv0 = ConvBNLayer(
-            num_filters=num_filters, filter_size=1, act='relu')
+            num_channels=num_channels,
+            num_filters=num_filters,
+            filter_size=1,
+            act='relu')
         self.conv1 = ConvBNLayer(
-            num_filters=num_filters, filter_size=3, stride=stride, act='relu')
+            num_channels=num_filters,
+            num_filters=num_filters,
+            filter_size=3,
+            stride=stride,
+            act='relu')
         self.conv2 = ConvBNLayer(
-            num_filters=num_filters * 4, filter_size=1, act=None)
+            num_channels=num_filters,
+            num_filters=num_filters * 4,
+            filter_size=1,
+            act=None)
 
-        if shortcut:
+        if not shortcut:
             self.short = ConvBNLayer(
-                num_filters=num_filters * 4, filter_size=1, stride=stride)
+                num_channels=num_channels,
+                num_filters=num_filters * 4,
+                filter_size=1,
+                stride=stride)
 
         self.shortcut = shortcut
 
+        self._num_channels_out = num_filters * 4
+
     def forward(self, inputs):
-        self.conv0()
-        self.conv1()
-        self.conv2()
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+        conv2 = self.conv2(conv1)
 
         if self.shortcut:
-            self.short()
+            short = inputs
+        else:
+            short = self.short(inputs)
 
-        return fluid.layers.elementwise_add(
-            x=self.short, y=self.conv2, act='relu')
+        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
 
 
 class ResNet(fluid.imperative.Layer):
     def __init__(self, layers=50, class_dim=1000):
+        super(ResNet, self).__init__()
+
+        self.layers = layers
         supported_layers = [50, 101, 152]
         assert layers in supported_layers, \
@@ -128,20 +158,23 @@ class ResNet(fluid.imperative.Layer):
             num_filters = [64, 128, 256, 512]
 
         self.conv = ConvBNLayer(
-            num_filters=64, filter_size=7, stride=2, act='relu')
+            num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
         self.pool2d_max = Pool2D(
             pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
 
         self.bottleneck_block_list = []
+        num_channels = 64
         for block in range(len(depth)):
-            shortcut = True
+            shortcut = False
             for i in range(depth[block]):
                 bottleneck_block = BottleneckBlock(
+                    num_channels=num_channels,
                     num_filters=num_filters[block],
                     stride=2 if i == 0 and block != 0 else 1,
                     shortcut=shortcut)
+                num_channels = bottleneck_block._num_channels_out
                 self.bottleneck_block_list.append(bottleneck_block)
-                shortcut = False
+                shortcut = True
 
         self.pool2d_avg = Pool2D(
             pool_size=7, pool_type='avg', global_pooling=True)
@@ -160,12 +193,12 @@
         for bottleneck_block in self.bottleneck_block_list:
             y = bottleneck_block(y)
         y = self.pool2d_avg(y)
-        y = self.out()
+        y = self.out(y)
         return y
 
 
 class TestImperativeResnet(unittest.TestCase):
-    def test_resnet_cpu_float32(self):
+    def test_resnet_gpu_float32(self):
         seed = 90
 
         with fluid.imperative.guard():
@@ -183,17 +216,17 @@ class TestImperativeResnet(unittest.TestCase):
                     break
 
                 x_data = np.array(
-                    [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
+                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
                 y_data = np.array([x[1] for x in data]).astype('int64').reshape(
-                    128, 1)
+                    256, 1)
 
                 img = to_variable(x_data)
                 label = to_variable(y_data)
                 label._stop_gradient = True
 
-                cost = resnet(img)
+                out = resnet(img)
                 loss = fluid.layers.cross_entropy(input=out, label=label)
-                avg_loss = fluid.layers.mean(x=cost)
+                avg_loss = fluid.layers.mean(x=loss)
 
                 dy_out = avg_loss._numpy()
 
                 if batch_id == 0: