未验证 提交 4c9b3daf 编写于 作者: W wangchaochaohu 提交者: GitHub

fill_constant_batch_size_like OP precious problem fix (#21337)

* fix fill_constant_batch_size_like_op precious problem  test=develop
上级 46401786
......@@ -38,6 +38,8 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
.SetDefault(framework::proto::VarType::FP32);
AddAttr<float>("value", "default 0. The value to be filled")
.SetDefault(0.0f);
AddAttr<std::string>("str_value", "default empty. The value to be filled")
.SetDefault("");
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -22,14 +23,15 @@ namespace operators {
template <typename DeviceContext, typename T>
class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto value = ctx.Attr<float>("value");
auto float_value = ctx.Attr<float>("value");
auto str_value = ctx.Attr<std::string>("str_value");
auto force_cpu = ctx.Attr<bool>("force_cpu");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("Input");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
// set the correct batch size for the LoDTensor.
auto odims = out->dims();
......@@ -38,15 +40,39 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
out->mutable_data<T>(odims, ctx.GetPlace());
}
if (force_cpu) {
out->mutable_data(platform::CPUPlace(), data_type);
T value;
if (str_value.empty()) {
value = static_cast<T>(float_value);
} else {
out->mutable_data(ctx.GetPlace(), data_type);
std::stringstream convert_stream(str_value);
if (std::is_same<int64_t, T>::value) {
int64_t tmp_value;
convert_stream >> tmp_value;
value = static_cast<T>(tmp_value);
} else {
double tmp_value;
convert_stream >> tmp_value;
value = static_cast<T>(tmp_value);
}
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(ctx.GetPlace());
math::set_constant(dev_ctx, out, value);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
math::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(), data_type);
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
out, static_cast<T>(value));
}
#ifdef PADDLE_WITH_CUDA
if (!cpu_place) {
math::SetConstant<platform::CUDADeviceContext, T> functor;
out->mutable_data(ctx.GetPlace(), data_type);
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
out, static_cast<T>(value));
}
#endif
}
};
......
......@@ -668,18 +668,23 @@ def fill_constant_batch_size_like(input,
"""
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
attrs = {
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx,
'force_cpu': force_cpu or force_init_on_cpu()
}
if convert_dtype(dtype) in ['int64', 'int32']:
attrs['str_value'] = str(int(value))
else:
attrs['str_value'] = str(float(value))
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx,
'force_cpu': force_cpu or force_init_on_cpu()
})
attrs=attrs)
out.stop_gradient = True
return out
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest
......@@ -52,6 +53,20 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest):
self.check_output()
class TestFillConstantBatchSizeLikeInt64(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("int64")}
self.attrs = {'value': 5894589485094, 'shape': [-1, 132, 7]}
out = np.random.random((219, 132, 7)).astype("int64")
out.fill(5894589485094)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
......@@ -74,5 +89,20 @@ class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
self.check_output()
# Test python API
class TestFillConstantBatchSizeLikeAPI(OpTest):
def test_api(self):
like = fluid.layers.fill_constant(
shape=[1, 200], value=10, dtype='int64')
out = fluid.layers.fill_constant_batch_size_like(
input=like, shape=[2, 300], value=1315454564656, dtype='int64')
exe = fluid.Executor(place=fluid.CPUPlace())
res, = exe.run(fluid.default_main_program(), fetch_list=[out])
assert np.array_equal(
res[0], np.full(
[300], 1315454564656, dtype="int64"))
if __name__ == "__main__":
unittest.main()
......@@ -2682,6 +2682,14 @@ class TestBook(LayerTest):
x, axes=axes, starts=starts, ends=ends, strides=strides)
return out
def test_fill_constant_batch_size_like(self):
with self.static_graph():
like = fluid.layers.fill_constant(
shape=[1, 200], value=10, dtype='int64')
out = layers.fill_constant_batch_size_like(
input=like, shape=[2, 3300], value=1315454564656, dtype='int64')
return out
def test_psroi_pool(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册