提交 dc6e8146 编写于 作者: P phlrain

fix concat shape; test=develop

上级 909e1341
......@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <memory>
#include <string>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
namespace paddle {
namespace operators {
using framework::Tensor;
......@@ -45,11 +49,29 @@ class ConcatOp : public framework::OperatorWithKernel {
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
out_dims[axis] += ins[i][j];
if (ctx->IsRuntime()) {
out_dims[axis] += ins[i][j];
} else {
if (out_dims[axis] == -1 || ins[i][j] == -1) {
out_dims[axis] = -1;
} else {
out_dims[axis] += ins[i][j];
}
}
} else {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
if (ctx->IsRuntime()) {
// check all shape in run time
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
} else {
// not check -1 with other in compile time
if (out_dims[j] > 0 && ins[i][j] > 0) {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
}
}
}
}
}
......@@ -59,6 +81,22 @@ class ConcatOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]);
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -66,6 +104,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
AddAttr<bool>(
"use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
.SetDefault(0);
......@@ -87,11 +129,7 @@ Examples:
class ConcatOpGrad : public framework::OperatorWithKernel {
public:
ConcatOpGrad(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_x = "X";
......@@ -109,6 +147,33 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
"X");
class ConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("concat_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
......@@ -116,9 +181,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
paddle::framework::DefaultGradOpDescMaker<
false> /* set false to disable empty grad */);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad);
ops::ConcatGradOpDescMaker);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
ops::ConcatOpGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册