未验证 提交 aab4d12c 编写于 作者: J jerrywgz 提交者: GitHub

refine GetExpectedKernelType in conat op, test=develop (#17934)

上级 3ece61f7
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using Tensor = framework::Tensor;
class ConcatOp : public framework::OperatorWithKernel { class ConcatOp : public framework::OperatorWithKernel {
public: public:
...@@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto vars = ctx.MultiInputVar("X"); auto inputs = ctx.MultiInput<Tensor>("X");
auto input_data_type = framework::proto::VarType::Type(0); auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0; bool flag = 0;
for (auto *var : vars) { for (auto *input : inputs) {
if (var->IsInitialized()) { if (input->IsInitialized() && input->numel() > 0) {
input_data_type = framework::GetDataTypeOfVar(var); input_data_type = input->type();
flag = 1; flag = 1;
break; break;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册