diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index 4a0f41ae54d768a71dd883efc020f1674bcef89f..edc538c5056697c2f8c65bdefa1f31ce7d0c8ab8 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/shape_op.h" +#include #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -30,6 +31,15 @@ class ShapeOp : public framework::OperatorWithKernel { auto in_dim = ctx->GetInputDim("Input"); ctx->SetOutputDim("Out", {in_dim.size()}); } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } }; class ShapeOpMaker : public framework::OpProtoAndCheckerMaker {