diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index e2c0e55ef8656fff38314c8b646b942ce2c3db25..7a0c93eb1b2eaa7afaae7f0a604a0da5ac0fd75d 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -11,31 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include #include -#include #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/fill_constant_op.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -namespace details { -template -struct RandomDistributionType { - using Type = T; -}; - -template <> -struct RandomDistributionType { - using Type = float; -}; -} // namespace details - template struct GaussianGenerator { T mean_, std_; @@ -49,16 +34,12 @@ struct GaussianGenerator { : mean_(mean), std_(std), seed_(seed), offset_(offset) {} __host__ __device__ T operator()(const unsigned int n) const { - using DataType = typename details::RandomDistributionType::Type; - thrust::minstd_rand rng; rng.seed(seed_); - thrust::normal_distribution dist(static_cast(mean_), - static_cast(std_)); + thrust::normal_distribution dist(mean_, std_); unsigned int new_n = n + offset_; rng.discard(new_n); - T out = static_cast(dist(rng)); - return out; + return dist(rng); } }; @@ -141,13 +122,10 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL( - gaussian_random, paddle::operators::GPUGaussianRandomKernel, - paddle::operators::GPUGaussianRandomKernel, - paddle::operators::GPUGaussianRandomKernel); +REGISTER_OP_CUDA_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel, + paddle::operators::GPUGaussianRandomKernel); REGISTER_OP_CUDA_KERNEL( gaussian_random_batch_size_like, paddle::operators::GPUGaussianRandomBatchSizeLikeKernel, - paddle::operators::GPUGaussianRandomBatchSizeLikeKernel, - paddle::operators::GPUGaussianRandomBatchSizeLikeKernel< - paddle::platform::float16>); + paddle::operators::GPUGaussianRandomBatchSizeLikeKernel); diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index cec574a6adafbfff7ce8c4399773a62a25faa37c..563a6c165b748543516eabbcdb0e1c8b9be8a44d 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -18,7 +18,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -164,12 +163,9 @@ class GPUUniformRandomKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL( - uniform_random, paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); -REGISTER_OP_CUDA_KERNEL( - uniform_random_batch_size_like, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel, - paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL(uniform_random, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); +REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like, + paddle::operators::GPUUniformRandomKernel, + paddle::operators::GPUUniformRandomKernel); diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 79291e2da882a4d21c0bbfc4d1ccd732f0057f82..496eb78f20ef7bd25db07f68bb15202b2f7f2972 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -131,10 +131,6 @@ struct PADDLE_ALIGN(2) float16 { #endif } - HOSTDEVICE inline float16(int32_t val) : float16(static_cast(val)) {} - - HOSTDEVICE inline float16(uint32_t val) : float16(static_cast(val)) {} - HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {} template diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index f60e66a37cfe6cabdb0e1154b9201bfaeb757566..1cf93fa7e1920d89e43a63051254f71fc328d06e 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -267,35 +267,18 @@ def cast_net_to_fp16(program): op._set_attr('dtype', core.VarDesc.VarType.FP16) -def cast_parameters_to_fp16(program): +def cast_parameters_to_fp16(exe, program, scope=None): + exe_scope = scope if scope is not None else global_scope() global_block = program.global_block() all_parameters = global_block.all_parameters() - is_bn_params = lambda param: (param.name.find('bn') != -1 and (param.name.endswith('_offset') or param.name.endswith('_mean') or param.name.endswith('_scale') or param.name.endswith('_variance'))) - all_param_names = {p.name for p in all_parameters if not is_bn_params(p)} - ops = global_block.ops - for param in all_parameters: - if param.name in all_param_names: - param_var = global_block.var(param.name) - if param_var.dtype == core.VarDesc.VarType.FP32: - param_var.desc.set_dtype(core.VarDesc.VarType.FP16) - - for op in ops: - target_op = False - for out_name in op.output_names: - for out_var_name in op.output(out_name): - if out_var_name in all_param_names: - target_op = True - if target_op: - if op.has_attr('in_dtype') and op.attr( - 'in_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('in_dtype', core.VarDesc.VarType.FP16) - if op.has_attr('out_dtype') and op.attr( - 'out_dtype') == core.VarDesc.VarType.FP32: - op._set_attr('out_dtype', core.VarDesc.VarType.FP16) - if op.has_attr('dtype') and op.attr( - 'dtype') == core.VarDesc.VarType.FP32: - op._set_attr('dtype', core.VarDesc.VarType.FP16) + if not (param.name.find('bn') != -1 and + (param.name.endswith('_offset') or param.name.endswith('_mean') + or param.name.endswith('_scale') or + param.name.endswith('_variance'))): + param_t = exe_scope.find_var(param.name).get_tensor() + data = np.array(param_t) + param_t.set(np.float16(data), exe.place) def rewrite_program(main_prog, amp_lists):