Commit 0c05ea39 authored by wanghaoshuang

Pull latest pybind.cc to crop_op

Parent 46888c32
@@ -29,6 +29,10 @@ class CropOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto x_dim = ctx.Input<LoDTensor>("X")->dims();
     auto Y = ctx.Input<LoDTensor>("Y");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+                            "Input(X) of CropOp should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
+                            "Output(Out) of CropOp should not be null.");
     if (Y == nullptr) {
       auto shape = Attr<std::vector<int>>("shape");
       PADDLE_ENFORCE_EQ(
@@ -40,6 +44,9 @@ class CropOp : public framework::OperatorWithKernel {
       }
       ctx.Output<LoDTensor>("Out")->Resize(framework::make_ddim(tensor_shape));
     } else {
+      PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(Y->dims()),
+                        "Tensor rank of both CropOp's "
+                        "inputs must be same.");
       ctx.Output<LoDTensor>("Out")->Resize(Y->dims());
     }
   }
......
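For reference, the checks added above guard CropOp's two shape-inference paths: when no reference input Y is supplied, the output shape comes from the "shape" attribute, which must give one extent per dimension of X; when Y is supplied, the output copies Y's dims and X and Y must have the same rank. Below is a minimal standalone sketch of that rule with hypothetical names; it is not the operator's actual code.

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring CropOp's shape inference: the output shape
// comes either from the "shape" attribute or from a reference tensor's dims.
std::vector<int64_t> InferCropShape(const std::vector<int64_t>& x_dims,
                                    const std::vector<int64_t>* y_dims,
                                    const std::vector<int>& shape_attr) {
  if (y_dims == nullptr) {
    // No reference input: the attribute must give one extent per dimension of X.
    if (shape_attr.size() != x_dims.size()) {
      throw std::invalid_argument("shape attribute rank must match Input(X)");
    }
    return std::vector<int64_t>(shape_attr.begin(), shape_attr.end());
  }
  // Reference input given: ranks must agree and the output follows Y.
  if (y_dims->size() != x_dims.size()) {
    throw std::invalid_argument("Tensor rank of both CropOp's inputs must be same.");
  }
  return *y_dims;
}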
@@ -19,8 +19,7 @@
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
+using framework::LoDTensor;
 
 template <typename T, int D>
 __global__ void CropKernel(const int N, const int64_t* out_shape,
......
@@ -24,8 +24,7 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
+using framework::LoDTensor;
 
 template <typename Place, typename T, size_t D>
 void CropGradFunction(const framework::ExecutionContext& context) {
......
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/framework/backward.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/cond_op.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
 #include "paddle/platform/enforce.h"
@@ -288,6 +289,28 @@ All parameter, weight, gradient are variables in Paddle.
            [](operators::RecurrentOp &self, const operators::NetOp &net)
                -> void { self.set_stepnet(net.Clone()); });
+
+  // cond_op
+  py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
+      .def_static("create",
+                  [](py::bytes protobin) -> operators::CondOp * {
+                    OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    auto cond_op = OpRegistry::CreateOp(desc);
+                    return static_cast<operators::CondOp *>(cond_op.release());
+                  })
+      .def("set_truenet",
+           [](operators::CondOp &self, const operators::NetOp &net) -> void {
+             self.set_truenet(net.Clone());
+           })
+      .def("set_falsenet",
+           [](operators::CondOp &self, const operators::NetOp &net) -> void {
+             self.set_falsenet(net.Clone());
+           });
   m.def("unique_integer", UniqueIntegerGenerator);
   m.def("is_compile_gpu", IsCompileGPU);
......
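The CondOp block registered above follows the usual pybind11 idiom: a def_static factory whose lambda deserializes an OpDesc protobuf and returns an owned pointer, plus def'd setters that clone a NetOp into the true and false branches. As a rough illustration of that binding pattern only, here is a self-contained sketch with a toy class; the names below are hypothetical and not part of Paddle.

#include <pybind11/pybind11.h>

#include <memory>
#include <string>

namespace py = pybind11;

// Toy stand-in for an operator that is created from serialized bytes and
// later has two sub-networks attached to it.
struct ToyCondOp {
  std::string desc;
  std::string true_net;
  std::string false_net;
};

PYBIND11_MODULE(toy_ops, m) {
  py::class_<ToyCondOp>(m, "ToyCondOp")
      // Static factory, analogous to CondOp's "create"; real code would parse
      // a protobuf here and fail loudly if the message were incomplete.
      .def_static("create",
                  [](py::bytes protobin) {
                    auto op = std::make_unique<ToyCondOp>();
                    op->desc = std::string(protobin);
                    return op;
                  })
      // Instance setters, analogous to set_truenet / set_falsenet.
      .def("set_truenet",
           [](ToyCondOp &self, const std::string &net) { self.true_net = net; })
      .def("set_falsenet",
           [](ToyCondOp &self, const std::string &net) { self.false_net = net; });
}

From Python, the toy module would be driven in the same order the real binding is presumably meant to be used: create the op from serialized bytes, then attach the two sub-nets.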