diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc index 09fa13dfbbf4f850f69126aa09849421e671d29e..33fa9b792879b0dab1f4c5efb4512da950ca4496 100644 --- a/paddle/operators/crop_op.cc +++ b/paddle/operators/crop_op.cc @@ -29,6 +29,10 @@ class CropOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto x_dim = ctx.Input("X")->dims(); auto Y = ctx.Input("Y"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of CropOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of CropOp should not be null."); if (Y == nullptr) { auto shape = Attr>("shape"); PADDLE_ENFORCE_EQ( @@ -40,6 +44,9 @@ class CropOp : public framework::OperatorWithKernel { } ctx.Output("Out")->Resize(framework::make_ddim(tensor_shape)); } else { + PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(Y->dims()), + "Tensor rank of both CropOp's " + "inputs must be same."); ctx.Output("Out")->Resize(Y->dims()); } } diff --git a/paddle/operators/crop_op.cu b/paddle/operators/crop_op.cu index 1715b2eaf9ccb4363311219e0245f965fb107b8e..561dbe48039e444429e0593f37cfd82bddce6471 100644 --- a/paddle/operators/crop_op.cu +++ b/paddle/operators/crop_op.cu @@ -19,8 +19,7 @@ namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; +using framework::LoDTensor; template __global__ void CropKernel(const int N, const int64_t* out_shape, diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h index 7f041737a7e87d76fc6208bb2e05e881b45badab..09d42f4b7ee44091b0e41f608e3db96281e740ba 100644 --- a/paddle/operators/crop_op.h +++ b/paddle/operators/crop_op.h @@ -24,8 +24,7 @@ template using EigenTensor = framework::EigenTensor; -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; +using framework::LoDTensor; template void CropGradFunction(const framework::ExecutionContext& context) { diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index a7a38339fb2c8689778b0a86d3713f67e1447a80..c7009a604f60cda11434ad33b6c7d7caee1befdd 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/cond_op.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" @@ -288,6 +289,28 @@ All parameter, weight, gradient are variables in Paddle. [](operators::RecurrentOp &self, const operators::NetOp &net) -> void { self.set_stepnet(net.Clone()); }); + // cond_op + py::class_(m, "CondOp") + .def_static("create", + [](py::bytes protobin) -> operators::CondOp * { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + auto cond_op = OpRegistry::CreateOp(desc); + return static_cast(cond_op.release()); + }) + .def("set_truenet", + [](operators::CondOp &self, const operators::NetOp &net) -> void { + self.set_truenet(net.Clone()); + }) + .def("set_falsenet", + [](operators::CondOp &self, const operators::NetOp &net) -> void { + self.set_falsenet(net.Clone()); + }); + m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU);