Unverified commit 4c9b3daf authored by wangchaochaohu, committed by GitHub

fill_constant_batch_size_like OP precision problem fix (#21337)

* fix fill_constant_batch_size_like_op precision problem  test=develop
Parent 46401786
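Why the fix is needed: the op previously carried its fill value only through a 32-bit float attribute, so large integer fill values were silently rounded before they ever reached the kernel. A minimal Python sketch of that failure mode (illustrative only, reusing the constant from the new int64 test further down):

```python
import numpy as np

value = 5894589485094  # the int64 fill value used in the new test below

# Old path: the value rides in a 32-bit float attr and gets rounded.
via_float_attr = int(np.float32(value))
print(via_float_attr == value)  # False -- precision lost

# New path: the value also rides in a string attr and is parsed exactly.
via_str_attr = int("5894589485094")
print(via_str_attr == value)  # True
```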
@@ -38,6 +38,8 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
         .SetDefault(framework::proto::VarType::FP32);
     AddAttr<float>("value", "default 0. The value to be filled")
         .SetDefault(0.0f);
+    AddAttr<std::string>("str_value", "default empty. The value to be filled")
+        .SetDefault("");
     AddAttr<bool>("force_cpu",
                   "(bool, default false) Force fill output variable to cpu "
                   "memory. Otherwise, fill output variable to the running "
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <string>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -22,14 +23,15 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     auto data_type =
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
-    auto value = ctx.Attr<float>("value");
+    auto float_value = ctx.Attr<float>("value");
+    auto str_value = ctx.Attr<std::string>("str_value");
     auto force_cpu = ctx.Attr<bool>("force_cpu");
-    auto* out = ctx.Output<framework::Tensor>("Out");
-    auto* in = ctx.Input<framework::LoDTensor>("Input");
+    auto *out = ctx.Output<framework::Tensor>("Out");
+    auto *in = ctx.Input<framework::LoDTensor>("Input");
     if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
       // set the correct batch size for the LoDTensor.
       auto odims = out->dims();
@@ -38,15 +40,39 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
       out->mutable_data<T>(odims, ctx.GetPlace());
     }
 
-    if (force_cpu) {
-      out->mutable_data(platform::CPUPlace(), data_type);
+    T value;
+    if (str_value.empty()) {
+      value = static_cast<T>(float_value);
     } else {
-      out->mutable_data(ctx.GetPlace(), data_type);
+      std::stringstream convert_stream(str_value);
+      if (std::is_same<int64_t, T>::value) {
+        int64_t tmp_value;
+        convert_stream >> tmp_value;
+        value = static_cast<T>(tmp_value);
+      } else {
+        double tmp_value;
+        convert_stream >> tmp_value;
+        value = static_cast<T>(tmp_value);
+      }
     }
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto& dev_ctx = *pool.Get(ctx.GetPlace());
-    math::set_constant(dev_ctx, out, value);
+
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(ctx.GetPlace());
+    bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
+    if (cpu_place) {
+      math::SetConstant<platform::CPUDeviceContext, T> functor;
+      out->mutable_data(platform::CPUPlace(), data_type);
+      functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
+              out, static_cast<T>(value));
+    }
+#ifdef PADDLE_WITH_CUDA
+    if (!cpu_place) {
+      math::SetConstant<platform::CUDADeviceContext, T> functor;
+      out->mutable_data(ctx.GetPlace(), data_type);
+      functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
+              out, static_cast<T>(value));
+    }
+#endif
   }
 };
...
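For readability, here is a Python restatement of the kernel's new value-selection logic above (a sketch that mirrors the C++ code, not part of the PR): the float attr is only a fallback, and when str_value is non-empty it is parsed as int64 for integer kernels and as double otherwise, avoiding any float round-trip.

```python
def resolve_fill_value(float_value, str_value, kernel_is_int64):
    """Mirror of the C++ kernel's value selection (illustrative sketch)."""
    if not str_value:
        # Legacy path: fall back to the float attribute.
        return float_value
    if kernel_is_int64:
        # Parse directly as an integer -- no float round-trip,
        # so values beyond float precision stay exact.
        return int(str_value)
    # Floating-point kernels parse through double (Python float).
    return float(str_value)

assert resolve_fill_value(0.0, "5894589485094", True) == 5894589485094
```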
@@ -668,18 +668,23 @@ def fill_constant_batch_size_like(input,
     """
     helper = LayerHelper("fill_constant_batch_size_like", **locals())
     out = helper.create_variable_for_type_inference(dtype=dtype)
-    helper.append_op(
-        type='fill_constant_batch_size_like',
-        inputs={'Input': input},
-        outputs={'Out': [out]},
-        attrs={
-            'shape': shape,
-            'dtype': out.dtype,
-            'value': float(value),
-            'input_dim_idx': input_dim_idx,
-            'output_dim_idx': output_dim_idx,
-            'force_cpu': force_cpu or force_init_on_cpu()
-        })
+    attrs = {
+        'shape': shape,
+        'dtype': out.dtype,
+        'value': float(value),
+        'input_dim_idx': input_dim_idx,
+        'output_dim_idx': output_dim_idx,
+        'force_cpu': force_cpu or force_init_on_cpu()
+    }
+    if convert_dtype(dtype) in ['int64', 'int32']:
+        attrs['str_value'] = str(int(value))
+    else:
+        attrs['str_value'] = str(float(value))
+    helper.append_op(
+        type='fill_constant_batch_size_like',
+        inputs={'Input': input},
+        outputs={'Out': [out]},
+        attrs=attrs)
     out.stop_gradient = True
     return out
...
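A usage sketch of the patched Python API (fluid 1.x style; it mirrors the new unit test below): because the layer now also sends str_value, a large int64 fill value round-trips exactly.

```python
import numpy as np
import paddle.fluid as fluid

like = fluid.layers.fill_constant(shape=[1, 200], value=10, dtype='int64')
# Dim 0 of `shape` is replaced by dim 0 of `input`, giving output [1, 300].
out = fluid.layers.fill_constant_batch_size_like(
    input=like, shape=[2, 300], value=1315454564656, dtype='int64')

exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(fluid.default_main_program(), fetch_list=[out])
assert np.array_equal(res[0], np.full([300], 1315454564656, dtype="int64"))
```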
@@ -16,6 +16,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
+import paddle.fluid as fluid
 from op_test import OpTest
@@ -52,6 +53,20 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest):
         self.check_output()
 
 
+class TestFillConstantBatchSizeLikeInt64(OpTest):
+    def setUp(self):
+        self.op_type = "fill_constant_batch_size_like"
+        self.inputs = {'Input': np.random.random((219, 232)).astype("int64")}
+        self.attrs = {'value': 5894589485094, 'shape': [-1, 132, 7]}
+
+        out = np.random.random((219, 132, 7)).astype("int64")
+        out.fill(5894589485094)
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
     def setUp(self):
         self.op_type = "fill_constant_batch_size_like"
@@ -74,5 +89,20 @@ class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
         self.check_output()
 
 
+# Test python API
+class TestFillConstantBatchSizeLikeAPI(OpTest):
+    def test_api(self):
+        like = fluid.layers.fill_constant(
+            shape=[1, 200], value=10, dtype='int64')
+        out = fluid.layers.fill_constant_batch_size_like(
+            input=like, shape=[2, 300], value=1315454564656, dtype='int64')
+
+        exe = fluid.Executor(place=fluid.CPUPlace())
+        res, = exe.run(fluid.default_main_program(), fetch_list=[out])
+        assert np.array_equal(
+            res[0], np.full(
+                [300], 1315454564656, dtype="int64"))
+
+
 if __name__ == "__main__":
     unittest.main()
...
@@ -2682,6 +2682,14 @@ class TestBook(LayerTest):
                 x, axes=axes, starts=starts, ends=ends, strides=strides)
             return out
 
+    def test_fill_constant_batch_size_like(self):
+        with self.static_graph():
+            like = fluid.layers.fill_constant(
+                shape=[1, 200], value=10, dtype='int64')
+            out = layers.fill_constant_batch_size_like(
+                input=like, shape=[2, 3300], value=1315454564656, dtype='int64')
+            return out
+
     def test_psroi_pool(self):
         # TODO(minqiyang): dygraph do not support lod now
         with self.static_graph():
...