未验证 提交 90f4d5e9 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

modify fill constant batch size like (#5222)

上级 08ca7267
...@@ -36,7 +36,12 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -36,7 +36,12 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
[](int a) { return static_cast<int64_t>(a); }); [](int a) { return static_cast<int64_t>(a); });
auto dims = framework::make_ddim(shape_int64); auto dims = framework::make_ddim(shape_int64);
dims[0] = ctx->GetInputDim("Input")[0]; int dim_idx = ctx->Attrs().Get<int>("dim_idx");
PADDLE_ENFORCE_GE(dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), dim_idx);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), dim_idx);
dims[dim_idx] = ctx->GetInputDim("Input")[dim_idx];
ctx->SetOutputDim("Out", dims); ctx->SetOutputDim("Out", dims);
} }
...@@ -57,15 +62,18 @@ class FillConstantBatchSizeLikeOpMaker ...@@ -57,15 +62,18 @@ class FillConstantBatchSizeLikeOpMaker
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::DataType::FP32); .SetDefault(framework::DataType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddInput("Input", AddInput("Input",
"(Tensor) Tensor " "(Tensor) Tensor "
"whose first dimension is used to specify the batch_size"); "whose dim_idx th dimension is used to specify the batch_size");
AddOutput("Out", AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled " "(Tensor) Tensor of specified shape will be filled "
"with the specified value"); "with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("dim_idx",
"(int, default 0) the index of batch size dimension")
.SetDefault(0);
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddComment(R"DOC(Fill up a variable with specified constant value.)DOC"); AddComment(R"DOC(Fill up a variable with specified constant value.)DOC");
} }
}; };
......
...@@ -3,13 +3,27 @@ import numpy as np ...@@ -3,13 +3,27 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
class TestFillConstantBatchSizeLikeOp(OpTest): class TestFillConstantBatchSizeLikeWhenFirstDimIsBatchSize(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fill_constant_batch_size_like" self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("float32")} self.inputs = {'Input': np.random.random((219, 232)).astype("float32")}
self.attrs = {'value': 3.5, 'shape': [-1, 132, 777]} self.attrs = {'value': 3.5, 'shape': [-1, 132, 7]}
out = np.random.random((219, 132, 777)).astype("float32") out = np.random.random((219, 132, 7)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {'Input': np.random.random((219, 232)).astype("float32")}
self.attrs = {'value': 3.5, 'shape': [132, -1, 7], 'dim_idx': 1}
out = np.random.random((132, 232, 7)).astype("float32")
out.fill(3.5) out.fill(3.5)
self.outputs = {'Out': out} self.outputs = {'Out': out}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册