Commit 57a3b8b6 authored by wanghaoshuang

1. Implement GPUCrop kernel instead of Eigen.

2. Fix unittest.
Parent 57011b20
@@ -19,6 +19,7 @@ namespace paddle {
 namespace operators {
 using framework::Tensor;
+using framework::LoDTensor;
 
 class CropOp : public framework::OperatorWithKernel {
  public:
@@ -26,8 +27,8 @@ class CropOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto x_dim = ctx.Input<Tensor>("X")->dims();
-    auto Y = ctx.Input<Tensor>("Y");
+    auto x_dim = ctx.Input<LoDTensor>("X")->dims();
+    auto Y = ctx.Input<LoDTensor>("Y");
     if (Y == nullptr) {
       auto shape = Attr<std::vector<int>>("shape");
       PADDLE_ENFORCE_EQ(
@@ -37,9 +38,9 @@ class CropOp : public framework::OperatorWithKernel {
       for (size_t i = 0; i < shape.size(); ++i) {
         tensor_shape[i] = (int64_t)shape[i];
       }
-      ctx.Output<Tensor>("Out")->Resize(framework::make_ddim(tensor_shape));
+      ctx.Output<LoDTensor>("Out")->Resize(framework::make_ddim(tensor_shape));
     } else {
-      ctx.Output<Tensor>("Out")->Resize(Y->dims());
+      ctx.Output<LoDTensor>("Out")->Resize(Y->dims());
     }
   }
 };
@@ -112,8 +113,8 @@ class CropOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto x_dims = ctx.Input<LoDTensor>("X")->dims();
+    auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
     if (x_grad != nullptr) {
       x_grad->Resize(x_dims);
     }
@@ -141,23 +142,17 @@ template <typename T>
 class CropCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    LOG(INFO) << "CropCPUKernel step1";
-    auto *x = context.Input<Tensor>("X");
-    LOG(INFO) << "CropCPUKernel step2";
-    auto *out = context.Output<Tensor>("Out");
-    LOG(INFO) << "CropCPUKernel step3";
+    auto *x = context.Input<LoDTensor>("X");
+    auto *out = context.Output<LoDTensor>("Out");
     auto x_data = x->data<T>();
-    T *out_data = out->mutable_data<T>(paddle::platform::CPUPlace());
-    LOG(INFO) << "CropCPUKernel step4";
+    T *out_data = out->mutable_data<T>(context.GetPlace());
     auto x_dims = x->dims();
     auto out_dims = out->dims();
-    LOG(INFO) << "CropCPUKernel step5";
     int64_t out_count = framework::product(out_dims);
     std::vector<int64_t> x_shape = framework::vectorize(x_dims);
     std::vector<int64_t> out_shape = framework::vectorize(out_dims);
     auto offsets = context.op().Attr<std::vector<int>>("offsets");
-    LOG(INFO) << "CropCPUKernel step6";
     PADDLE_ENFORCE_EQ(
         x_dims.size(), offsets.size(),
         "Offsets size should be equal to dimension size of input tensor.");
@@ -171,7 +166,6 @@ class CropCPUKernel : public framework::OpKernel {
     for (int64_t i = 0; i < out_count; ++i) {
       out_data[i] = x_data[transIndex(out_shape, x_shape, crop_rules, i)];
     }
-    LOG(INFO) << "CropCPUKernel step7";
   }
 };
...
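The transIndex call in CropCPUKernel above maps each flat output index to the flat input index it copies from; its body is outside this diff's context. Below is a minimal standalone C++ sketch of that translation, assuming row-major layout and assuming the diff's crop_rules argument boils down to per-dimension offsets; trans_index is a hypothetical stand-in, not Paddle's actual helper.

    // trans_index: map a flat output index to the flat input index,
    // shifting each coordinate by its crop offset (assumed semantics).
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int64_t trans_index(const std::vector<int64_t>& out_shape,
                        const std::vector<int64_t>& x_shape,
                        const std::vector<int>& offsets, int64_t out_idx) {
      int64_t x_idx = 0;
      int64_t x_stride = 1;
      int64_t out_stride = 1;
      // Walk dimensions from innermost to outermost (row-major layout).
      for (int d = static_cast<int>(out_shape.size()) - 1; d >= 0; --d) {
        int64_t coord = (out_idx / out_stride) % out_shape[d];
        x_idx += (coord + offsets[d]) * x_stride;
        out_stride *= out_shape[d];
        x_stride *= x_shape[d];
      }
      return x_idx;
    }

    int main() {
      // Crop a 2x2 window at offset (1, 2) from a 4x4 input: output
      // element (0, 0) must come from input element (1, 2), flat index 6.
      std::vector<int64_t> x_shape{4, 4}, out_shape{2, 2};
      std::vector<int> offsets{1, 2};
      assert(trans_index(out_shape, x_shape, offsets, 0) == 6);
      return 0;
    }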
@@ -20,6 +20,7 @@ namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 
 template <typename T, int D>
 __global__ void CropKernel(const int N, const int64_t* out_shape,
@@ -48,9 +49,8 @@ template <typename T, int D>
 void CropCUDAFunctoin(const framework::ExecutionContext& context) {
   PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                  "It must use GPUPlace.");
-  LOG(INFO) << "CropCUDAFunctoin step1";
-  auto* x = context.Input<Tensor>("X");
-  auto* out = context.Output<Tensor>("Out");
+  auto* x = context.Input<LoDTensor>("X");
+  auto* out = context.Output<LoDTensor>("Out");
   auto x_data = x->data<T>();
   T* out_data = out->mutable_data<T>(paddle::platform::GPUPlace());
   auto x_dims = x->dims();
@@ -100,7 +100,7 @@ template <typename T>
 class CropOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    size_t rank = context.Input<Tensor>("X")->dims().size();
+    size_t rank = context.Input<LoDTensor>("X")->dims().size();
     switch (rank) {
       case 1:
         CropCUDAFunctoin<T, 1>(context);
...
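Per the commit message, the GPU path now uses the handwritten CropKernel above instead of an Eigen expression; the kernel body itself is collapsed in this view. The following is a minimal CUDA sketch of such a kernel, assuming one thread per output element and the same row-major index arithmetic as the CPU path; the names, argument layout, and launch configuration are illustrative, not Paddle's actual API.

    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    // One thread per output element; each thread decodes its output
    // coordinates, shifts them by the crop offsets, and reads the input.
    template <typename T, int D>
    __global__ void CropKernelSketch(const int n, const int64_t* out_shape,
                                     const int64_t* x_shape,
                                     const int* offsets, const T* x, T* out) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n) return;
      int64_t x_idx = 0, x_stride = 1, rem = i;
      for (int d = D - 1; d >= 0; --d) {  // innermost dimension first
        int64_t coord = rem % out_shape[d];
        rem /= out_shape[d];
        x_idx += (coord + offsets[d]) * x_stride;
        x_stride *= x_shape[d];
      }
      out[i] = x[x_idx];
    }

    int main() {
      // Crop a 2x2 window at offset (1, 2) from a 4x4 input holding 0..15;
      // the expected output is 6 7 10 11.
      const int64_t h_x_shape[2] = {4, 4}, h_out_shape[2] = {2, 2};
      const int h_offsets[2] = {1, 2};
      float h_x[16], h_out[4];
      for (int i = 0; i < 16; ++i) h_x[i] = static_cast<float>(i);
      float *d_x, *d_out;
      int64_t *d_xs, *d_os;
      int* d_off;
      cudaMalloc(&d_x, sizeof(h_x));
      cudaMalloc(&d_out, sizeof(h_out));
      cudaMalloc(&d_xs, sizeof(h_x_shape));
      cudaMalloc(&d_os, sizeof(h_out_shape));
      cudaMalloc(&d_off, sizeof(h_offsets));
      cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
      cudaMemcpy(d_xs, h_x_shape, sizeof(h_x_shape), cudaMemcpyHostToDevice);
      cudaMemcpy(d_os, h_out_shape, sizeof(h_out_shape), cudaMemcpyHostToDevice);
      cudaMemcpy(d_off, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);
      CropKernelSketch<float, 2><<<1, 32>>>(4, d_os, d_xs, d_off, d_x, d_out);
      cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
      printf("%g %g %g %g\n", h_out[0], h_out[1], h_out[2], h_out[3]);
      return 0;
    }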
@@ -25,11 +25,12 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 
 template <typename Place, typename T, size_t D>
 void CropGradFunction(const framework::ExecutionContext& context) {
-  auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-  auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+  auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+  auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
   if (d_x != nullptr) {
     d_x->mutable_data<T>(context.GetPlace());
     auto d_x_dims = d_x->dims();
@@ -52,7 +53,7 @@ class CropGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     size_t rank =
-        context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
+        context.Input<LoDTensor>(framework::GradVarName("Out"))->dims().size();
     switch (rank) {
       case 1:
         CropGradFunction<Place, T, 1>(context);
...
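CropGradFunction above appears to compute the gradient of crop with an Eigen expression (its body is collapsed here); mathematically, grad-of-crop is zero-padding: d_x is zero everywhere except the cropped window, which receives d_out unchanged. A plain C++ sketch of that scatter, under the stated assumption; crop_grad is a hypothetical helper, not part of the operator.

    #include <cstdint>
    #include <vector>

    // Zero-initialize d_x, then scatter each d_out element back to the
    // input position it was cropped from (same index math as the forward).
    void crop_grad(const std::vector<int64_t>& x_shape,
                   const std::vector<int64_t>& out_shape,
                   const std::vector<int>& offsets,
                   const std::vector<float>& d_out,
                   std::vector<float>* d_x) {
      int64_t x_count = 1, out_count = 1;
      for (int64_t s : x_shape) x_count *= s;
      for (int64_t s : out_shape) out_count *= s;
      d_x->assign(x_count, 0.0f);  // gradient vanishes outside the window
      for (int64_t i = 0; i < out_count; ++i) {
        int64_t x_idx = 0, x_stride = 1, rem = i;
        for (int d = static_cast<int>(out_shape.size()) - 1; d >= 0; --d) {
          int64_t coord = rem % out_shape[d];
          rem /= out_shape[d];
          x_idx += (coord + offsets[d]) * x_stride;
          x_stride *= x_shape[d];
        }
        (*d_x)[x_idx] = d_out[i];  // crop is injective, so plain assignment
      }
    }

Expressing this with an Eigen pad expression, as the header does, yields the same result declaratively and lets the Place template parameter dispatch the computation to CPU or GPU.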
@@ -64,7 +64,6 @@ def set_input(scope, op, inputs, place):
             tensor.set_dims(in_array.shape)
             tensor.set(in_array, place)
             if isinstance(in_val, tuple):
-                print "set lod"
                 tensor.set_lod(in_val[1])
@@ -189,10 +188,8 @@ class OpTest(unittest.TestCase):
         self.op.infer_shape(self.scope)
         ctx = core.DeviceContext.create(place)
         self.op.run(self.scope, ctx)
-        print "finish self.op.run"
 
         for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
-            print "finish Operator.get_op_outputs"
-            print "out_dup=%s; out_name=%s" % (out_dup, out_name)
             if out_dup:
                 sub_out = self.outputs[out_name]
                 for sub_out_name in sub_out:
@@ -204,17 +201,12 @@ class OpTest(unittest.TestCase):
                             actual, expect, atol=1e-05),
                         "output name: " + out_name + "has diff")
             else:
-                v = self.scope.find_var(out_name)
-                print "var=%s" % v
-                print "tensor=%s" % v.get_tensor()
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
-                print "actual=%s" % actual
                 expect = self.outputs[out_name]
                 self.assertTrue(
                     np.allclose(
                         actual, expect, atol=1e-05),
                     "output name: " + out_name + "has diff")
-        print "finish check in %s" % place
 
     def check_output(self):
         places = [core.CPUPlace()]
...
@@ -47,45 +47,44 @@ class TestCropOp(OpTest):
     def initTestCase(self):
         self.x_shape = (8, 8)
-        self.crop_shape = [2, 2]
+        self.crop_shape = (2, 2)
         self.offsets = [1, 2]
 
     def test_check_output(self):
         self.check_output()
-        print "finish check_output"
 
-    #def test_check_grad_normal(self):
-    #    self.check_grad(['X'], 'Out', max_relative_error=0.006)
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.006)
 
 
-#class TestCase1(TestCropOp):
-#    def initTestCase(self):
-#        self.x_shape = (16, 16, 16)
-#        self.crop_shape = [2, 2, 3]
-#        self.offsets = [1, 5, 3]
-#
-#
-#class TestCase2(TestCropOp):
-#    def initTestCase(self):
-#        self.x_shape = (4, 4)
-#        self.crop_shape = [4, 4]
-#        self.offsets = [0, 0]
-#
-#
-#class TestCase3(TestCropOp):
-#    def initTestCase(self):
-#        self.x_shape = (16, 16, 16)
-#        self.crop_shape = [2, 2, 3]
-#        self.offsets = [1, 5, 3]
-#        self.crop_by_input = True
-#
-#
-#class TestCase4(TestCropOp):
-#    def initTestCase(self):
-#        self.x_shape = (4, 4)
-#        self.crop_shape = [4, 4]
-#        self.offsets = [0, 0]
-#        self.crop_by_input = True
-#
+class TestCase1(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (16, 8, 32)
+        self.crop_shape = [2, 2, 3]
+        self.offsets = [1, 5, 3]
+
+
+class TestCase2(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (4, 8)
+        self.crop_shape = [4, 8]
+        self.offsets = [0, 0]
+
+
+class TestCase3(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (4, 8, 16)
+        self.crop_shape = [2, 2, 3]
+        self.offsets = [1, 5, 3]
+        self.crop_by_input = True
+
+
+class TestCase4(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (4, 4)
+        self.crop_shape = [4, 4]
+        self.offsets = [0, 0]
+        self.crop_by_input = True
+
 
 if __name__ == '__main__':
...