未验证 提交 6514f52e 编写于 作者: W wangchaochaohu 提交者: GitHub

fix the fill_constant op precision problem (#21322)

* fix the fill_constant op precision problem test=develop
上级 08c19c58
...@@ -90,8 +90,12 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -90,8 +90,12 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"The shape of the element in vector must be [1].") "The shape of the element in vector must be [1].")
.AsDuplicable() .AsDuplicable()
.AsDispensable(); .AsDispensable();
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0.0f) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
AddAttr<std::string>(
"str_value",
"(string, default empty) The str convert to value to be filled")
.SetDefault("");
AddAttr<bool>("force_cpu", AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu " "(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running " "memory. Otherwise, fill output variable to the running "
......
...@@ -14,8 +14,9 @@ limitations under the License. */ ...@@ -14,8 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <sstream>
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -75,13 +76,28 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -75,13 +76,28 @@ class FillConstantKernel : public framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext &ctx) const override { void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto data_type = auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto value = ctx.Attr<float>("value"); auto str_value = ctx.Attr<std::string>("str_value");
auto float_value = ctx.Attr<float>("value");
auto force_cpu = ctx.Attr<bool>("force_cpu"); auto force_cpu = ctx.Attr<bool>("force_cpu");
framework::Tensor *tensor = nullptr; framework::Tensor *tensor = nullptr;
framework::Variable *out_var = ctx.OutputVar("Out"); framework::Variable *out_var = ctx.OutputVar("Out");
T value;
if (str_value.empty()) {
value = static_cast<T>(float_value);
} else {
std::stringstream convert_stream(str_value);
if (std::is_same<int64_t, T>::value) {
int64_t tmp_value;
convert_stream >> tmp_value;
value = static_cast<T>(tmp_value);
} else {
double tmp_value;
convert_stream >> tmp_value;
value = static_cast<T>(tmp_value);
}
}
auto shape = GetShape(ctx); auto shape = GetShape(ctx);
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
...@@ -96,15 +112,23 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -96,15 +112,23 @@ class FillConstantKernel : public framework::OpKernel<T> {
"supports SelectedRows and LoDTensor"); "supports SelectedRows and LoDTensor");
} }
if (force_cpu) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
tensor->mutable_data(platform::CPUPlace(), data_type); tensor->mutable_data(platform::CPUPlace(), data_type);
} else { math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#ifdef PADDLE_WITH_CUDA
if (!cpu_place) {
tensor->mutable_data(ctx.GetPlace(), data_type); tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
} }
#endif
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
math::set_constant(dev_ctx, tensor, value);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -552,6 +552,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -552,6 +552,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
'force_cpu': force_cpu or force_init_on_cpu() 'force_cpu': force_cpu or force_init_on_cpu()
} }
if convert_dtype(dtype) in ['int64', 'int32']:
attrs['str_value'] = str(int(value))
else:
attrs['str_value'] = str(float(value))
def _contain_var(one_list): def _contain_var(one_list):
for ele in one_list: for ele in one_list:
if isinstance(ele, Variable): if isinstance(ele, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册