未验证 提交 c1aae8b8 编写于 作者: J jerrywgz 提交者: GitHub

Fix GetExpectedKernelType in Concat op (#17459)

* fix concat op vartype check, test=develop
上级 58f7695a
...@@ -80,8 +80,19 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -80,8 +80,19 @@ class ConcatOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto vars = ctx.MultiInputVar("X");
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *var : vars) {
if (var->IsInitialized()) {
input_data_type = framework::GetDataTypeOfVar(var);
flag = 1;
break;
}
}
if (flag == 0) {
PADDLE_THROW("All Inputs of Concat OP are Empty!");
}
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册