// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/random_crop_op.h" namespace paddle { namespace operators { class RandomCropOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.device_context()); } }; class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", ""); AddOutput("Out", ""); AddInput("Seed", ""); AddOutput("SeedOut", "").AsDispensable(); AddAttr>("shape", ""); AddComment(""); } }; class RandomCropOpInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { auto seed_dim = ctx->GetInputDim("Seed"); PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1); auto shape = ctx->Attrs().Get>("shape"); auto x_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE_GT(x_dim.size(), static_cast(shape.size())); auto out_dim = framework::vectorize2int(x_dim); for (size_t i = 1; i <= shape.size(); ++i) { size_t x_i = x_dim.size() - i; size_t shape_i = shape.size() - i; PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]); out_dim[x_i] = shape[shape_i]; } ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); ctx->SetOutputDim("SeedOut", framework::make_ddim({1})); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace f = paddle::framework; REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker, ops::RandomCropOpInferShape, f::EmptyGradOpMaker); template using Kernel = ops::RandomCropKernel; REGISTER_OP_CPU_KERNEL(random_crop, Kernel, Kernel, Kernel, Kernel, Kernel);