提交 a7fe3b50 编写于 作者: P phlrain

fix concat; test=develop

上级 909e1341
...@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
...@@ -47,9 +50,19 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -47,9 +50,19 @@ class ConcatOp : public framework::OperatorWithKernel {
if (j == axis) { if (j == axis) {
out_dims[axis] += ins[i][j]; out_dims[axis] += ins[i][j];
} else { } else {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], if (ctx->IsRuntime()) {
"Input tensors should have the same " // check all shape in run time
"elements except the specify axis."); PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
} else {
// not check -1 in compile time
if (out_dims[j] != -1 && ins[i][j] != -1) {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
}
}
} }
} }
} }
...@@ -59,6 +72,22 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -59,6 +72,22 @@ class ConcatOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]);
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -66,6 +95,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,6 +95,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator."); AddOutput("Out", "Output tensor of concat operator.");
AddAttr<bool>(
"use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddAttr<int>("axis", AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.") "The axis along which the input tensors will be concatenated.")
.SetDefault(0); .SetDefault(0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册