From 2c29cf1ea5ebf1ee73090e1002690d480af252d1 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 19 Sep 2017 01:06:16 +0800 Subject: [PATCH] Use Tensor as the temp variables instead of CUDA api --- paddle/operators/crop_op.cc | 46 +++++++++++++++++----------------- paddle/operators/crop_op.cu | 50 ++++++++++++++++++------------------- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc index 33fa9b79287..ee4bc9cdafb 100644 --- a/paddle/operators/crop_op.cc +++ b/paddle/operators/crop_op.cc @@ -27,12 +27,12 @@ class CropOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto x_dim = ctx.Input("X")->dims(); - auto Y = ctx.Input("Y"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) of CropOp should not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), "Output(Out) of CropOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); + auto Y = ctx.Input("Y"); if (Y == nullptr) { auto shape = Attr>("shape"); PADDLE_ENFORCE_EQ( @@ -40,7 +40,7 @@ class CropOp : public framework::OperatorWithKernel { "Shape size should be equal to dimention size of input tensor."); std::vector tensor_shape(shape.size()); for (size_t i = 0; i < shape.size(); ++i) { - tensor_shape[i] = (int64_t)shape[i]; + tensor_shape[i] = static_cast(shape[i]); } ctx.Output("Out")->Resize(framework::make_ddim(tensor_shape)); } else { @@ -65,6 +65,15 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output of crop op " "with the same dimension as X."); + AddAttr>("offsets", + "A list describing offsets to be cropped." + "The size of offsets list should be as same as " + "dimension size of input X."); + AddAttr>("shape", + "A list describing the shape of output." + "The size of shape list should be as same as " + "dimension size of input X.") + .SetDefault(std::vector()); AddComment(R"DOC( Crop Operator. Crop input into output, as specified by offsets and shape. @@ -81,33 +90,24 @@ The input should be a k-D tensor(k > 0 and k < 7). As an example: Given: -X = [[0, 1, 2, 0, 0] - [0, 3, 4, 0, 0] - [0, 0, 0, 0, 0]] + X = [[0, 1, 2, 0, 0] + [0, 3, 4, 0, 0] + [0, 0, 0, 0, 0]] and -offsets = [0, 1] + offsets = [0, 1] and -shape = [2, 2] + shape = [2, 2] then we get -Out = [[1, 2], - [3, 4]] + Out = [[1, 2], + [3, 4]] )DOC"); - AddAttr>("offsets", - "A list describing offsets to be cropped." - "The size of offsets list should be as same as " - "dimension size of input X."); - AddAttr>("shape", - "A list describing the shape of output." - "The size of shape list should be as same as " - "dimension size of input X.") - .SetDefault(std::vector()); } }; @@ -149,17 +149,17 @@ template class CropCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - auto *x = context.Input("X"); - auto *out = context.Output("Out"); + auto *x = context.Input("X"); + auto *out = context.Output("Out"); auto x_data = x->data(); T *out_data = out->mutable_data(context.GetPlace()); auto x_dims = x->dims(); auto out_dims = out->dims(); - int64_t out_count = framework::product(out_dims); + int64_t out_count = out->numel(); std::vector x_shape = framework::vectorize(x_dims); std::vector out_shape = framework::vectorize(out_dims); - auto offsets = context.op().Attr>("offsets"); + auto offsets = context.Attr>("offsets"); PADDLE_ENFORCE_EQ( x_dims.size(), offsets.size(), "Offsets size should be equal to dimension size of input tensor."); diff --git a/paddle/operators/crop_op.cu b/paddle/operators/crop_op.cu index a40eb7af385..f499ce3f275 100644 --- a/paddle/operators/crop_op.cu +++ b/paddle/operators/crop_op.cu @@ -20,6 +20,7 @@ namespace paddle { namespace operators { using framework::LoDTensor; +using framework::Tensor; template __global__ void CropKernel(const int N, const int64_t* out_shape, @@ -54,35 +55,36 @@ void CropCUDAFunctoin(const framework::ExecutionContext& context) { T* out_data = out->mutable_data(paddle::platform::GPUPlace()); auto x_dims = x->dims(); auto out_dims = out->dims(); - int64_t out_count = framework::product(out_dims); - int64_t x_shape[D]; - int64_t out_shape[D]; + int64_t out_count = out->numel(); + Tensor x_shape; + Tensor out_shape; + int64_t* x_shape_data = + x_shape.mutable_data({D}, paddle::platform::CPUPlace()); + int64_t* out_shape_data = + out_shape.mutable_data({D}, paddle::platform::CPUPlace()); for (int i = 0; i < D; ++i) { - x_shape[i] = x_dims[i]; - out_shape[i] = out_dims[i]; + x_shape_data[i] = x_dims[i]; + out_shape_data[i] = out_dims[i]; } - int64_t* x_shape_gpu; - int64_t* out_shape_gpu; - cudaMalloc((void**)&x_shape_gpu, sizeof(int64_t) * D); - cudaMemcpy(x_shape_gpu, x_shape, sizeof(int64_t) * D, cudaMemcpyHostToDevice); - cudaMalloc((void**)&out_shape_gpu, sizeof(int64_t) * D); - cudaMemcpy(out_shape_gpu, out_shape, sizeof(int64_t) * D, - cudaMemcpyHostToDevice); + Tensor x_shape_gpu; + Tensor out_shape_gpu; + x_shape_gpu.CopyFrom(x_shape, paddle::platform::GPUPlace()); + out_shape_gpu.CopyFrom(out_shape, paddle::platform::GPUPlace()); auto offsets = context.op().Attr>("offsets"); PADDLE_ENFORCE_EQ( D, offsets.size(), "Offsets size should be equal to dimension size of input tensor."); - int crop_rules[D * 2]; - for (size_t i = 0; i < x_dims.size(); ++i) { - crop_rules[i * 2] = offsets[i]; - crop_rules[i * 2 + 1] = x_dims[i] - out_dims[i] - offsets[i]; + Tensor crop_rules; + int* crop_rules_data = + crop_rules.mutable_data({D * 2}, paddle::platform::CPUPlace()); + for (size_t i = 0; i < D; ++i) { + crop_rules_data[i * 2] = offsets[i]; + crop_rules_data[i * 2 + 1] = x_dims[i] - out_dims[i] - offsets[i]; } - int* crop_rules_gpu; - cudaMalloc((void**)&crop_rules_gpu, sizeof(int) * D * 2); - cudaMemcpy(crop_rules_gpu, crop_rules, sizeof(int) * D * 2, - cudaMemcpyHostToDevice); + Tensor crop_rules_gpu; + crop_rules_gpu.CopyFrom(crop_rules, paddle::platform::GPUPlace()); int n = out_dims[0]; int d = out_dims[1]; @@ -94,11 +96,9 @@ void CropCUDAFunctoin(const framework::ExecutionContext& context) { CropKernel<<(device_context) - ->stream()>>>(out_count, out_shape_gpu, x_shape_gpu, - crop_rules_gpu, x_data, out_data); - cudaFree(crop_rules_gpu); - cudaFree(x_shape_gpu); - cudaFree(out_shape_gpu); + ->stream()>>>( + out_count, out_shape_gpu.data(), x_shape_gpu.data(), + crop_rules_gpu.data(), x_data, out_data); } template -- GitLab