未验证 提交 aab4d12c 编写于 作者: J jerrywgz 提交者: GitHub

refine GetExpectedKernelType in conat op, test=develop (#17934)

上级 3ece61f7
......@@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::Tensor;
using Tensor = framework::Tensor;
class ConcatOp : public framework::OperatorWithKernel {
public:
......@@ -80,12 +80,12 @@ class ConcatOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto vars = ctx.MultiInputVar("X");
auto inputs = ctx.MultiInput<Tensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *var : vars) {
if (var->IsInitialized()) {
input_data_type = framework::GetDataTypeOfVar(var);
for (auto *input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
input_data_type = input->type();
flag = 1;
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册