diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 8af1d62dea89343ff2d41dd7c6ac837459df7685..f216aad9d9fd9e7a1d0f811e1fd56eec37b1eedf 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -359,7 +359,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
             BuildStrategy::GradientScaleStrategy::kCustomized) {
           // TODO(paddle-dev): Why is there no input for this op_handle?
           auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
-          CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
+          auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
+          CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
+                                out_dtype);
         }
         // This assumes the backward generating code will ensure IsScaleLossOp
         // is true only for the op that scale the final scalar loss.
@@ -662,13 +664,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
     ir::Graph *result, const std::string &loss_grad_name,
-    ir::Node *out_var_node) const {
+    ir::Node *out_var_node, proto::VarType::Type dtype) const {
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
     auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
     auto *op_handle = new ScaleLossGradOpHandle(
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
-        local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
+        local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx, dtype);
     result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 8e462aec7dc7ce45cad592b89de0b6edde8c9146..1428fe830594ae93ad45d70154efcf68d84f98c4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -68,7 +68,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
 
   void CreateScaleLossGradOp(ir::Graph *result,
                              const std::string &loss_grad_name,
-                             ir::Node *out_var_node) const;
+                             ir::Node *out_var_node,
+                             proto::VarType::Type dtype) const;
 
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                             int dst_dev_id) const;
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
index ef1626599795a553e654fe5d3ed74ef3a3a67d78..e1b8e8fe05f0615d689e78d9c405cc5d76d2abb1 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -22,39 +22,66 @@ namespace details {
 
 ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
                                              Scope *scope, platform::Place place,
-                                             platform::DeviceContext *dev_ctx)
+                                             platform::DeviceContext *dev_ctx,
+                                             proto::VarType::Type dtype)
     : OpHandleBase(node),
       coeff_(static_cast<float>(1.0 / num_dev)),
       scope_(scope),
-      place_(place) {
+      place_(place),
+      out_dtype_(dtype) {
   this->SetDeviceContext(place_, dev_ctx);
 }
 
 ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
 
+struct ScaleLossGradFunctor {
+  float coeff_;
+  Tensor *out_;
+  platform::Place place_;
+  OpHandleBase *op_handle_;
+  proto::VarType::Type out_dtype_;
+  platform::DeviceContext *ctx_;
+
+  ScaleLossGradFunctor(float coeff, Tensor *out, platform::Place place,
+                       OpHandleBase *op_handle, proto::VarType::Type dtype,
+                       platform::DeviceContext *ctx)
+      : coeff_(coeff), out_(out), place_(place), out_dtype_(dtype), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto *out_data = out_->mutable_data<OutT>(place_);
+    if (platform::is_cpu_place(place_)) {
+      *out_data = static_cast<OutT>(coeff_);
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      OutT cast_coeff = static_cast<OutT>(coeff_);
+      auto stream = static_cast<platform::CUDADeviceContext *>(ctx_)->stream();
+      memory::Copy(boost::get<platform::CUDAPlace>(place_), out_data,
+                   platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_),
+                   stream);
+      VLOG(10) << place_ << "RUN Scale loss grad op";
+
+#endif
+    }
+  }
+};
+
 void ScaleLossGradOpHandle::RunImpl() {
   // Doesn't wait any event
   std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
   auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
 
-  float *tmp = local_scope.FindVar(var_name)
-                   ->GetMutable<LoDTensor>()
-                   ->mutable_data<float>(make_ddim({1}), place_);
+  auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
+  tensor->Resize(make_ddim({1}));
 
-  if (platform::is_cpu_place(place_)) {
-    *tmp = coeff_;
-  } else {
 #ifdef PADDLE_WITH_CUDA
-    this->RunAndRecordEvent([&] {
-      auto stream = static_cast<platform::CUDADeviceContext *>(
-                        this->dev_ctxes_.at(place_))
-                        ->stream();
-      memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
-                   platform::CPUPlace(), &coeff_, sizeof(float), stream);
-      VLOG(10) << place_ << "RUN Scale loss grad op";
-    });
+  ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_,
+                            this->dev_ctxes_.at(place_));
+  this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
+#else
+  ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, nullptr);
+  framework::VisitDataType(out_dtype_, func);
 #endif
-  }
 }
 
 std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
index 523b55724c82d4e2bef0520c10e5708c952a3ecc..8bedd1643eb9c5e591fa3c40995fcba08980b9fa 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
@@ -26,8 +26,8 @@ namespace details {
 
 struct ScaleLossGradOpHandle : public OpHandleBase {
   ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
-                        platform::Place place,
-                        platform::DeviceContext *context);
+                        platform::Place place, platform::DeviceContext *context,
+                        proto::VarType::Type dtype);
 
   ~ScaleLossGradOpHandle() final;
 
@@ -40,6 +40,7 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
   float coeff_;
   Scope *scope_;
   platform::Place place_;
+  proto::VarType::Type out_dtype_;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
index 1a149298fd33f132a90ff5de3b35dd5894a4ae68..ae669f5525443abe424109b6a6869e2ddaf52ba0 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
+                              paddle::platform::float16>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::float16>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
 REGISTER_OP_CUDA_KERNEL(
-    elementwise_mul, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    elementwise_mul, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, paddle::platform::float16>,
+    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mul_grad,
-    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, paddle::platform::float16>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/fill_zeros_like_op.cu.cc b/paddle/fluid/operators/fill_zeros_like_op.cu.cc
index 95381774606b2d8e74519befc9a6f7a3ac20aa45..e80a703c30c0335124c089ea82ba4f6fe055acde 100644
--- a/paddle/fluid/operators/fill_zeros_like_op.cu.cc
+++ b/paddle/fluid/operators/fill_zeros_like_op.cu.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/fill_zeros_like_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
@@ -22,4 +23,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::float16>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
diff --git a/paddle/fluid/operators/metrics/accuracy_op.cu b/paddle/fluid/operators/metrics/accuracy_op.cu
index b255d2a7c413b4f965f6b874d342dcb93c7b5e66..4682940f7e15bc8af5dcda24ea058ac7351887c6 100644
--- a/paddle/fluid/operators/metrics/accuracy_op.cu
+++ b/paddle/fluid/operators/metrics/accuracy_op.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 #include
 #include "paddle/fluid/operators/metrics/accuracy_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/gpu_info.h"
 
 namespace paddle {
@@ -94,6 +95,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
 
 // FIXME(typhoonzero): types of T is for inference data.
 // label data is always int64
-REGISTER_OP_CUDA_KERNEL(accuracy,
-                        paddle::operators::AccuracyOpCUDAKernel<float>,
-                        paddle::operators::AccuracyOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
+    paddle::operators::AccuracyOpCUDAKernel<double>,
+    paddle::operators::AccuracyOpCUDAKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/operators/optimizers/momentum_op.cu b/paddle/fluid/operators/optimizers/momentum_op.cu
index 8ce739de8dfd74cb43f9521bf39e3127a8a21925..7f9e7246401bc3c765e539ac4395c4feef3c9508 100644
--- a/paddle/fluid/operators/optimizers/momentum_op.cu
+++ b/paddle/fluid/operators/optimizers/momentum_op.cu
@@ -14,8 +14,11 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/optimizers/momentum_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext,
+                          paddle::platform::float16>);
diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h
index 71f079e4d97f5259359ee6572f584894551452ca..f6ef83c3bad23d709b386f8e75bbc97fa9ba0aab 100644
--- a/paddle/fluid/operators/optimizers/momentum_op.h
+++ b/paddle/fluid/operators/optimizers/momentum_op.h
@@ -237,7 +237,8 @@ class SparseMomentumFunctor {
 
   inline HOSTDEVICE void operator()(size_t i) {
     auto row_idx =
         math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
-    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
+    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
+                       : static_cast<T>(0);
     // put memory access in register
     const T p = p_[i];
     const T lr = lr_[0];
@@ -282,7 +283,8 @@ class SparseMomentumFunctor {
 
   inline HOSTDEVICE void operator()(size_t i) {
     auto row_idx =
         math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
-    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
+    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
+                       : static_cast<T>(0);
     // put memory access in register
     const T p = p_[i];
     const T lr = lr_[0];
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index 0cad224ca8860b0e4bc2e3f2bc1659235aadfe2d..99a4b1b7b0b33aebd9a1a49b0b771fe6fd134bb3 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/assert.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -150,7 +151,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
       if (k < MaxLength - (*beam)) {
         topk[k] = topk[k + *beam];
       } else {
-        topk[k].set(-INFINITY, -1);
+        topk[k].set(-static_cast<T>(INFINITY), -1);
       }
     }
     if (!(*is_empty)) {
@@ -160,7 +161,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
     }
 
     *max = topk[MaxLength - 1];
-    if ((*max).v == -1) *is_empty = true;
+    if ((*max).v == -static_cast<T>(1)) *is_empty = true;
     *beam = 0;
   }
 }
@@ -181,7 +182,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
       if (k < MaxLength - *beam) {
         topk[k] = topk[k + *beam];
       } else {
-        topk[k].set(-INFINITY, -1);
+        topk[k].set(-static_cast<T>(INFINITY), -1);
       }
     }
     if (!(*is_empty)) {
@@ -278,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
   bool firststep = true;
 
   for (int j = 0; j < MaxLength; j++) {
-    topk[j].set(-INFINITY, -1);
+    topk[j].set(-static_cast<T>(INFINITY), -1);
   }
   while (top_num) {
     ThreadGetTopK<T, MaxLength, BlockSize>(
@@ -362,5 +363,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
-                        paddle::operators::TopkOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    top_k, paddle::operators::TopkOpCUDAKernel<float>,
+    paddle::operators::TopkOpCUDAKernel<double>,
+    paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index cbb090adefda03717a634dab24132d36d1cfc648..6ce4bf8f13922e2756c3ee8f189bd36123d6964c 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -23,6 +23,7 @@
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
 
 #define NCCL_ID_VARNAME "NCCLID"
 
@@ -38,6 +39,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
     return ncclInt;
   } else if (type == framework::proto::VarType::INT64) {
     return ncclInt64;
+  } else if (type == framework::proto::VarType::FP16) {
+    return ncclFloat16;
   } else {
     PADDLE_THROW("Not supported");
   }
diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py
index 13d2893fd146b5a3d9100ee1ba6c2243cb9c411b..af02721eb72c1d0f8aa3d7ab8db504c4c33b64d5 100644
--- a/python/paddle/fluid/data_feeder.py
+++ b/python/paddle/fluid/data_feeder.py
@@ -44,6 +44,8 @@ class DataToLoDTensorConverter(object):
             self.dtype = 'int64'
         elif dtype == core.VarDesc.VarType.FP64:
             self.dtype = 'float64'
+        elif dtype == core.VarDesc.VarType.FP16:
+            self.dtype = 'float16'
         elif dtype == core.VarDesc.VarType.INT32:
             self.dtype = 'int32'
         elif dtype == core.VarDesc.VarType.UINT8:
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index b37ebbe5179ba6e36be70ff936cb8a3ca0d89d13..26d1f8f4d2bd67a35c4ec96a025ee273cec4dbd1 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -18,6 +18,7 @@ from . import framework
 import numpy as np
 import contextlib
 from .core import VarDesc
+from . import unique_name
 
 __all__ = [
     'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
@@ -207,16 +208,39 @@ class UniformInitializer(Initializer):
 
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
            self._seed = block.program.random_seed
+
+        # to be compatible of fp16 initalizers
+        if var.dtype == VarDesc.VarType.FP16:
+            out_dtype = VarDesc.VarType.FP32
+            out_var = block.create_var(
+                name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
+                shape=var.shape,
+                dtype=out_dtype,
+                type=VarDesc.VarType.LOD_TENSOR,
+                persistable=False)
+        else:
+            out_dtype = var.dtype
+            out_var = var
+
         op = block._prepend_op(
             type="uniform_random",
-            outputs={"Out": var},
+            outputs={"Out": out_var},
             attrs={
                 "shape": var.shape,
-                "dtype": int(var.dtype),
+                "dtype": out_dtype,
                 "min": self._low,
                 "max": self._high,
                 "seed": self._seed
             })
+
+        if var.dtype == VarDesc.VarType.FP16:
+            block.append_op(
+                type="cast",
+                inputs={"X": out_var},
+                outputs={"Out": var},
+                attrs={"in_dtype": out_var.dtype,
+                       "out_dtype": var.dtype})
+
         var.op = op
         return op
@@ -261,17 +285,39 @@ class NormalInitializer(Initializer):
 
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
            self._seed = block.program.random_seed
+
+        # to be compatible of fp16 initalizers
+        if var.dtype == VarDesc.VarType.FP16:
+            out_dtype = VarDesc.VarType.FP32
+            out_var = block.create_var(
+                name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
+                shape=var.shape,
+                dtype=out_dtype,
+                type=VarDesc.VarType.LOD_TENSOR,
+                persistable=False)
+        else:
+            out_dtype = var.dtype
+            out_var = var
+
         op = block._prepend_op(
             type="gaussian_random",
-            outputs={"Out": var},
+            outputs={"Out": out_var},
             attrs={
                 "shape": var.shape,
-                "dtype": int(var.dtype),
+                "dtype": out_dtype,
                 "mean": self._mean,
                 "std": self._std_dev,
                 "seed": self._seed,
                 "use_mkldnn": False
             })
+
+        if var.dtype == VarDesc.VarType.FP16:
+            block.append_op(
+                type="cast",
+                inputs={"X": out_var},
+                outputs={"Out": var},
+                attrs={"in_dtype": out_var.dtype,
+                       "out_dtype": var.dtype})
         var.op = op
         return op
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index dde05189722fef77e03a1c2d8f3cbae44a3e8245..06039b206b4ddb02e38035134e50b353b987074e 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -63,14 +63,18 @@ def noam_decay(d_model, warmup_steps):
     Returns:
         The decayed learning rate.
     """
-    with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter(1)
-
-        a = global_step**-0.5
-        b = (warmup_steps**-1.5) * global_step
-        lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)
-
-        return lr_value
+
+    def _lr_schedule(dtype):
+        with default_main_program()._lr_schedule_guard():
+            global_step = _decay_step_counter(1)
+
+            a = global_step**-0.5
+            b = (warmup_steps**-1.5) * global_step
+            lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)
+
+            return lr_value
+
+    return _lr_schedule
 
 
 def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -109,15 +113,19 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
         sgd_optimizer.minimize(avg_cost)
 
     """
-    with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter()
-
-        div_res = global_step / decay_steps
-        if staircase:
-            div_res = ops.floor(div_res)
-        decayed_lr = learning_rate * (decay_rate**div_res)
-
-        return decayed_lr
+
+    def _lr_schedule(dtype):
+        with default_main_program()._lr_schedule_guard():
+            global_step = _decay_step_counter()
+
+            div_res = global_step / decay_steps
+            if staircase:
+                div_res = ops.floor(div_res)
+            decayed_lr = learning_rate * (decay_rate**div_res)
+
+            return decayed_lr
+
+    return _lr_schedule
 
 
 def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -138,15 +146,19 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     Returns:
         The decayed learning rate
     """
-    with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter()
-
-        div_res = global_step / decay_steps
-        if staircase:
-            div_res = ops.floor(div_res)
-        decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
-
-        return decayed_lr
+
+    def _lr_schedule(dtype):
+        with default_main_program()._lr_schedule_guard():
+            global_step = _decay_step_counter()
+
+            div_res = global_step / decay_steps
+            if staircase:
+                div_res = ops.floor(div_res)
+            decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
+
+            return decayed_lr
+
+    return _lr_schedule
 
 
 def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
@@ -184,16 +196,20 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
                 staircase=True))
         sgd_optimizer.minimize(avg_cost)
     """
-    with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter()
-
-        div_res = global_step / decay_steps
-        if staircase:
-            div_res = ops.floor(div_res)
-
-        decayed_lr = learning_rate / (1 + decay_rate * div_res)
-
-        return decayed_lr
+
+    def _lr_schedule(dtype):
+        with default_main_program()._lr_schedule_guard():
+            global_step = _decay_step_counter()
+
+            div_res = global_step / decay_steps
+            if staircase:
+                div_res = ops.floor(div_res)
+
+            decayed_lr = learning_rate / (1 + decay_rate * div_res)
+
+            return decayed_lr
+
+    return _lr_schedule
 
 
 def polynomial_decay(learning_rate,
@@ -224,28 +240,33 @@
     Returns:
         Variable: The decayed learning rate
     """
-    with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter()
-
-        if cycle:
-            div_res = ops.ceil(global_step / decay_steps)
-            zero_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=0.0)
-            one_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=1.0)
-
-            with control_flow.Switch() as switch:
-                with switch.case(global_step == zero_var):
-                    tensor.assign(input=one_var, output=div_res)
-                decay_steps = decay_steps * div_res
-        else:
-            decay_steps_var = tensor.fill_constant(
-                shape=[1], dtype='float32', value=float(decay_steps))
-            global_step = nn.elementwise_min(x=global_step, y=decay_steps_var)
-
-        decayed_lr = (learning_rate - end_learning_rate) * \
-            ((1 - global_step / decay_steps) ** power) + end_learning_rate
-        return decayed_lr
+
+    def _lr_schedule(dtype, decay_steps=decay_steps):
+        with default_main_program()._lr_schedule_guard():
+            global_step = _decay_step_counter()
+
+            if cycle:
+                div_res = ops.ceil(global_step / decay_steps)
+                zero_var = tensor.fill_constant(
+                    shape=[1], dtype=dtype, value=0.0)
+                one_var = tensor.fill_constant(
+                    shape=[1], dtype=dtype, value=1.0)
+
+                with control_flow.Switch() as switch:
+                    with switch.case(global_step == zero_var):
+                        tensor.assign(input=one_var, output=div_res)
+                    decay_steps = decay_steps * div_res
+            else:
+                decay_steps_var = tensor.fill_constant(
+                    shape=[1], dtype=dtype, value=float(decay_steps))
+                global_step = nn.elementwise_min(
+                    x=global_step, y=decay_steps_var)
+
+            decayed_lr = (learning_rate - end_learning_rate) * \
+                ((1 - global_step / decay_steps) ** power) + end_learning_rate
+            return decayed_lr
+
+    return _lr_schedule
 
 
 def piecewise_decay(boundaries, values):
@@ -273,38 +294,42 @@
 
 
     """
-    with default_main_program()._lr_schedule_guard():
-        if len(values) - len(boundaries) != 1:
-            raise ValueError("len(values) - len(boundaries) should be 1")
-
-        global_step = _decay_step_counter()
-
-        lr = tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype='float32',
-            persistable=True,
-            name="learning_rate")
-
-        with control_flow.Switch() as switch:
-            for i in range(len(boundaries)):
-                boundary_val = tensor.fill_constant(
-                    shape=[1],
-                    dtype='float32',
-                    value=float(boundaries[i]),
-                    force_cpu=True)
-                value_var = tensor.fill_constant(
-                    shape=[1], dtype='float32', value=float(values[i]))
-                with switch.case(global_step < boundary_val):
-                    tensor.assign(value_var, lr)
-            last_value_var = tensor.fill_constant(
-                shape=[1],
-                dtype='float32',
-                value=float(values[len(values) - 1]))
-            with switch.default():
-                tensor.assign(last_value_var, lr)
-
-    return lr
+
+    def _lr_schedule(dtype):
+        with default_main_program()._lr_schedule_guard():
+            if len(values) - len(boundaries) != 1:
+                raise ValueError("len(values) - len(boundaries) should be 1")
+
+            global_step = _decay_step_counter()
+
+            lr = tensor.create_global_var(
+                shape=[1],
+                value=0.0,
+                dtype='float32',
+                persistable=True,
+                name="learning_rate")
+
+            with control_flow.Switch() as switch:
+                for i in range(len(boundaries)):
+                    boundary_val = tensor.fill_constant(
+                        shape=[1],
+                        dtype='float32',
+                        value=float(boundaries[i]),
+                        force_cpu=True)
+                    value_var = tensor.fill_constant(
+                        shape=[1], dtype='float32', value=float(values[i]))
+                    with switch.case(global_step < boundary_val):
+                        tensor.assign(value_var, lr)
+                last_value_var = tensor.fill_constant(
+                    shape=[1],
+                    dtype='float32',
+                    value=float(values[len(values) - 1]))
+                with switch.default():
+                    tensor.assign(last_value_var, lr)
+
+            return lr
+
+    return _lr_schedule
 
 
 def append_LARS(params_grads, learning_rate, weight_decay):
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index d8bc919784bf85538ef092b29ecdd8c88ae910d0..4d44ce50a310cc6c95318a159b15544d8628e0bf 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -2798,6 +2798,10 @@ def batch_norm(input,
     helper = LayerHelper('batch_norm', **locals())
     dtype = helper.input_dtype()
 
+    # use fp32 for bn parameter
+    if dtype == core.VarDesc.VarType.FP16:
+        dtype = core.VarDesc.VarType.FP32
+
     input_shape = input.shape
     if data_layout == 'NCHW':
         channel_num = input_shape[1]
@@ -2832,7 +2836,7 @@ def batch_norm(input,
             trainable=False,
             do_model_average=do_model_average_for_mean_and_var),
         shape=param_shape,
-        dtype=input.dtype)
+        dtype=dtype)
     mean.stop_gradient = True
 
     variance = helper.create_parameter(
@@ -2842,7 +2846,7 @@ def batch_norm(input,
             trainable=False,
             do_model_average=do_model_average_for_mean_and_var),
         shape=param_shape,
-        dtype=input.dtype)
+        dtype=dtype)
     variance.stop_gradient = True
 
     # create output
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 59c22d4e498814d468c78b10265b7afe35461dfb..58cfc498c9edd77163b2bd4cad2cb991b6f2b20c 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -50,17 +50,21 @@ class Optimizer(object):
 
     def __init__(self, learning_rate, regularization=None, name=None):
         if not isinstance(learning_rate, float) and \
-                not isinstance(learning_rate, framework.Variable):
-            raise TypeError("learning rate should be float or Variable")
+                not isinstance(learning_rate, framework.Variable) and \
+                not callable(learning_rate):
+            raise TypeError(
+                "learning rate should be float or Variable or callable(dtype)")
         self._name = name
         self.regularization = regularization
         self._learning_rate = learning_rate
         # the learning rate type should be inferenced from loss
         self._dtype = None
         # each program should have a independent learning rate
-        # program -> Variable(learning_rate)
+        # program -> Variable(learning_rate) or:
+        # program -> callable(return learning_rate Variable)
         self._learning_rate_map = dict()
-        if isinstance(self._learning_rate, framework.Variable):
+        if isinstance(self._learning_rate, framework.Variable) or \
+                callable(self._learning_rate):
             self._learning_rate_map[framework.default_main_program(
             )] = self._learning_rate
         # Dictionary of accumulators. Some optimizer subclasses need to
@@ -75,6 +79,11 @@ class Optimizer(object):
 
         if isinstance(lr, framework.Variable):
             return
+        elif callable(lr):
+            dtype = 'float32' if self._dtype is None else self._dtype
+            self._learning_rate_map[framework.default_main_program()] = lr(
+                dtype)
+            return
         else:
             if not isinstance(self._learning_rate, float):
                 raise TypeError(
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 76a707efdc0804be0316ab12c347ffed6199529a..0fe836683b029698b670bbb9f9bb258c2f3b68a0 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -368,6 +368,8 @@ class OpTest(unittest.TestCase):
                 place = core.CUDAPlace(0)
                 if core.is_float16_supported(place):
                     return [place]
+                else:
+                    return []
             else:
                 return []
         places = [fluid.CPUPlace()]
diff --git a/python/paddle/fluid/tests/unittests/test_accuracy_op.py b/python/paddle/fluid/tests/unittests/test_accuracy_op.py
index 1b2b53f2d4ce91ae7b5b191ed770b5338f0948c8..5257b0be6f61bc90a6492c44044c122485f4742c 100644
--- a/python/paddle/fluid/tests/unittests/test_accuracy_op.py
+++ b/python/paddle/fluid/tests/unittests/test_accuracy_op.py
@@ -22,8 +22,10 @@ from op_test import OpTest
 class TestAccuracyOp(OpTest):
     def setUp(self):
         self.op_type = "accuracy"
+        self.dtype = np.float32
+        self.init_dtype()
         n = 8192
-        infer = np.random.random((n, 1)).astype("float32")
+        infer = np.random.random((n, 1)).astype(self.dtype)
         indices = np.random.randint(0, 2, (n, 1))
         label = np.random.randint(0, 2, (n, 1))
         self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
@@ -34,14 +36,25 @@ class TestAccuracyOp(OpTest):
                     num_correct += 1
                     break
         self.outputs = {
-            'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
+            'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
             'Correct': np.array([num_correct]).astype("int32"),
             'Total': np.array([n]).astype("int32")
         }
 
+    def init_dtype(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
 
 
+class TestAccuracyOpFp16(TestAccuracyOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-3)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
index cadaf1df53af0af56afa8c3631b0f5ce390f318c..15d4db590edc9012604361751e9860ba63239bba 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -21,14 +21,16 @@ from op_test import OpTest
 class ElementwiseDivOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_div"
+        self.dtype = np.float32
+        self.init_dtype()
         """ Warning
         CPU gradient check error!
         'X': np.random.random((32,84)).astype("float32"),
         'Y': np.random.random((32,84)).astype("float32")
         """
         self.inputs = {
-            'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
-            'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+            'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype),
+            'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
         }
         self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
 
@@ -46,6 +48,9 @@ class ElementwiseDivOp(OpTest):
         self.check_grad(
             ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
 
+    def init_dtype(self):
+        pass
+
 
 class TestElementwiseDivOp_scalar(ElementwiseDivOp):
     def setUp(self):
@@ -126,5 +131,21 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
         }
 
 
+class TestElementwiseDivOpFp16(ElementwiseDivOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=1)
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(
+            ['Y'], 'Out', max_relative_error=1, no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
index 57ba34f833f824d13e0b82caea789f7f57622bc9..04840991883229614c1ca4890e5cec2e7ae21084 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -135,5 +135,10 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
         }
 
 
+class TestElementwiseMulOpFp16(ElementwiseMulOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py
index eec73d0beb39c49f535a03532e536092001c8445..20f1a110c35d689064c49efba246f078c3badd33 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_zeros_like_op.py
@@ -22,12 +22,22 @@ from op_test import OpTest
 class TestFillZerosLikeOp(OpTest):
     def setUp(self):
         self.op_type = "fill_zeros_like"
-        self.inputs = {'X': np.random.random((219, 232)).astype("float32")}
+        self.dtype = np.float32
+        self.init_dtype()
+        self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
         self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
 
+    def init_dtype(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
 
 
+class TestFillZerosLikeOpFp16(TestFillZerosLikeOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index 0d3e6d73e0149fe633b8f1de9041068c2e3bb293..e34a712d844c2d45f442d04f9350fbd7bc911a2a 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -97,7 +97,7 @@ class TestLearningRateDecay(unittest.TestCase):
         startup_prog = fluid.Program()
 
         with fluid.program_guard(main_prog, startup_prog):
-            decayed_lr = fluid_decay_fn(**kwargs)
+            decayed_lr = fluid_decay_fn(**kwargs)("float32")
 
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index cf4346cf2e7a099334ec273546901a91d0ad925d..77ec6f9b6bcda7568325698634fd4f86557cd1be 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -24,11 +24,13 @@ from op_test import OpTest
 class TestMomentumOp1(OpTest):
     def setUp(self):
         self.op_type = "momentum"
+        self.dtype = np.float32
+        self.init_dtype()
 
-        param = np.random.random((123, 321)).astype("float32")
-        grad = np.random.random((123, 321)).astype("float32")
-        velocity = np.zeros((123, 321)).astype("float32")
-        learning_rate = np.array([0.001]).astype("float32")
+        param = np.random.random((123, 321)).astype(self.dtype)
+        grad = np.random.random((123, 321)).astype(self.dtype)
+        velocity = np.zeros((123, 321)).astype(self.dtype)
+        learning_rate = np.array([0.001]).astype(self.dtype)
         mu = 0.0001
         use_nesterov = False
 
@@ -50,10 +52,21 @@ class TestMomentumOp1(OpTest):
 
         self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
 
+    def init_dtype(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
 
 
+class TestMomentumOpFp16(TestMomentumOp1):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-3)
+
+
 class TestMomentumOp2(OpTest):
     '''Test Momentum with default values for attributes
     '''
diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py
index 69b29db83a43d18c0825b610642009a0377b9901..21b5a62baf96bfb2d76a8c59133e8f5d1cb35aea 100644
--- a/python/paddle/fluid/tests/unittests/test_top_k_op.py
+++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py
@@ -23,8 +23,11 @@ class TestTopkOp(OpTest):
     def setUp(self):
         self.set_args()
         self.op_type = "top_k"
+        self.dtype = np.float32
+        self.init_dtype()
+
         k = self.top_k
-        input = np.random.random((self.row, k)).astype("float32")
+        input = np.random.random((self.row, k)).astype(self.dtype)
         output = np.ndarray((self.row, k))
         indices = np.ndarray((self.row, k)).astype("int64")
 
@@ -38,6 +41,9 @@ class TestTopkOp(OpTest):
 
         self.outputs = {'Out': output, 'Indices': indices}
 
+    def init_dtype(self):
+        pass
+
     def set_args(self):
         self.row = 32
         self.top_k = 1
@@ -46,6 +52,11 @@ class TestTopkOp(OpTest):
         self.check_output()
 
 
+class TestTopkOpFp16(TestTopkOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
 class TestTopkOp3d(OpTest):
     def setUp(self):
         self.op_type = "top_k"