From 6514f52e46e582314a84dd22f40db9d2cbb2e260 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Mon, 25 Nov 2019 19:51:37 +0800 Subject: [PATCH] fix the fill_constant op precision problem (#21322) * fix the fill_constant op precision problem test=develop --- paddle/fluid/operators/fill_constant_op.cc | 6 +++- paddle/fluid/operators/fill_constant_op.h | 42 +++++++++++++++++----- python/paddle/fluid/layers/tensor.py | 5 +++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 5850919cb3c..bfa9182b986 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -90,8 +90,12 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { "The shape of the element in vector must be [1].") .AsDuplicable() .AsDispensable(); - AddAttr("value", "(float, default 0) The value to be filled") + AddAttr("value", "(float, default 0.0f) The value to be filled") .SetDefault(0.0f); + AddAttr( + "str_value", + "(string, default empty) The str convert to value to be filled") + .SetDefault(""); AddAttr("force_cpu", "(bool, default false) Force fill output variable to cpu " "memory. Otherwise, fill output variable to the running " diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 1359d25df70..a972ff21173 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -14,8 +14,9 @@ limitations under the License. 
*/ #pragma once +#include +#include #include - #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -75,13 +76,28 @@ class FillConstantKernel : public framework::OpKernel { void Compute(const paddle::framework::ExecutionContext &ctx) const override { auto data_type = static_cast(ctx.Attr("dtype")); - auto value = ctx.Attr("value"); + auto str_value = ctx.Attr("str_value"); + auto float_value = ctx.Attr("value"); auto force_cpu = ctx.Attr("force_cpu"); - framework::Tensor *tensor = nullptr; framework::Variable *out_var = ctx.OutputVar("Out"); + T value; + if (str_value.empty()) { + value = static_cast(float_value); + } else { + std::stringstream convert_stream(str_value); + if (std::is_same::value) { + int64_t tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } else { + double tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } + } auto shape = GetShape(ctx); if (out_var->IsType()) { @@ -96,15 +112,23 @@ class FillConstantKernel : public framework::OpKernel { "supports SelectedRows and LoDTensor"); } - if (force_cpu) { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(ctx.GetPlace()); + bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); + if (cpu_place) { tensor->mutable_data(platform::CPUPlace(), data_type); - } else { + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), + tensor, static_cast(value)); + } +#ifdef PADDLE_WITH_CUDA + if (!cpu_place) { tensor->mutable_data(ctx.GetPlace(), data_type); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), + tensor, static_cast(value)); } - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(ctx.GetPlace()); - math::set_constant(dev_ctx, tensor, value); +#endif } }; } // namespace operators diff --git 
a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index acdf609481a..3366851711f 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -552,6 +552,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): 'force_cpu': force_cpu or force_init_on_cpu() } + if convert_dtype(dtype) in ['int64', 'int32']: + attrs['str_value'] = str(int(value)) + else: + attrs['str_value'] = str(float(value)) + def _contain_var(one_list): for ele in one_list: if isinstance(ele, Variable): -- GitLab