未验证 提交 566bf2ec 编写于 作者: T tensor-tang 提交者: GitHub

concat op support negative axis (#18045)

test=develop
上级 7e463c84
......@@ -36,7 +36,10 @@ class ConcatOp : public framework::OperatorWithKernel {
"Output(Out) of ConcatOp should not be null.");
auto ins = ctx->GetInputsDim("X");
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t axis =
ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
static_cast<int64_t>(ins[0].size()));
const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
......@@ -115,7 +118,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
"The axis along which the input tensors will be concatenated."
"The axis could also be negative numbers. Negative axis is "
"interpreted as counting from the end of the rank."
"i.e., axis + rank(X) th dimension.")
.SetDefault(0);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
......
......@@ -23,13 +23,22 @@ limitations under the License. */
namespace paddle {
namespace operators {
static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
if (axis < 0) {
axis = axis + rank;
}
return axis > 0 ? axis : 0;
}
template <typename DeviceContext, typename T>
class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
PADDLE_ENFORCE(ins[0], "The input should not be null.");
auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
static_cast<int64_t>(ins[0]->dims().size()));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
......@@ -83,8 +92,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
}
}
}
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
PADDLE_ENFORCE(ins[0], "The input should not be null.");
auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
static_cast<int64_t>(ins[0]->dims().size()));
// get output tensor that the name is not kEmptyVarName
std::vector<framework::Tensor*> outputs;
......
......@@ -25,9 +25,15 @@ class TestConcatOp(OpTest):
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
else:
self.actual_axis = self.axis
self.outputs = {
'Out': np.concatenate(
(self.x0, self.x1, self.x2), axis=self.axis)
(self.x0, self.x1, self.x2), axis=self.actual_axis)
}
def test_check_output(self):
......@@ -75,5 +81,13 @@ class TestConcatOp4(TestConcatOp):
pass
class TestConcatOp5(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
self.axis = -3
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册