/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { // Internal template using EigenTensor = framework::EigenTensor; using framework::Tensor; template class CropKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* out = context.Output("Out"); const T* x_data = x->data(); T* out_data = out->mutable_data(context.GetPlace()); auto x_stride = framework::stride(x->dims()); auto out_stride = framework::stride(out->dims()); auto offsets = context.Attr>("offsets"); PADDLE_ENFORCE_EQ( x->dims().size(), static_cast(offsets.size()), "Offsets size should be equal to dimension size of input tensor."); int64_t offset = 0; for (size_t i = 0; i < offsets.size(); ++i) { offset += (x_stride[i] * offsets[i]); } StridedMemcpy(context.device_context(), x_data + offset, x_stride, out->dims(), out_stride, out_data); } }; template void CropGradFunction(const framework::ExecutionContext& context) { auto* d_x = context.Output(framework::GradVarName("X")); if (d_x != nullptr) { auto* d_out = context.Input(framework::GradVarName("Out")); d_x->mutable_data(context.GetPlace()); auto offsets = context.Attr>("offsets"); Eigen::array, D> paddings; for (size_t i = 0; i < D; ++i) { paddings[i].first = offsets[i]; paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; } auto d_x_tensor = EigenTensor::From(*d_x); auto d_out_tensor = EigenTensor::From(*d_out); d_x_tensor.device(context.GetEigenDevice()) = d_out_tensor.pad(paddings, 0); } } template class CropGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { size_t rank = context.Input(framework::GradVarName("Out"))->dims().size(); switch (rank) { case 1: CropGradFunction(context); break; case 2: CropGradFunction(context); break; case 3: CropGradFunction(context); break; case 4: CropGradFunction(context); break; case 5: CropGradFunction(context); break; case 6: CropGradFunction(context); break; default: PADDLE_THROW( "CropOp only support tensors with no more than 6 dimensions."); } } }; } // namespace operators } // namespace paddle