From b1e33bea15ed2327e4a4a5b838c3fde8232b69fc Mon Sep 17 00:00:00 2001 From: duanboqiang Date: Wed, 10 Aug 2022 16:33:24 +0800 Subject: [PATCH] [phi] migration of class center sample infermeta (#45025) * add class center sample infershape * add yaml * modify unittest * modify unittest * remove comment --- .../fluid/operators/class_center_sample_op.cc | 33 +++++-------------- paddle/phi/api/yaml/legacy_api.yaml | 8 +++++ paddle/phi/infermeta/unary.cc | 29 ++++++++++++++++ paddle/phi/infermeta/unary.h | 11 +++++++ .../unittests/test_class_center_sample_op.py | 20 ++++++++++- python/paddle/nn/functional/common.py | 6 +++- 6 files changed, 80 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/class_center_sample_op.cc b/paddle/fluid/operators/class_center_sample_op.cc index c1132e5798..57f8bfb71f 100644 --- a/paddle/fluid/operators/class_center_sample_op.cc +++ b/paddle/fluid/operators/class_center_sample_op.cc @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,30 +24,6 @@ namespace operators { class ClassCenterSampleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK( - ctx->HasInput("Label"), "Input", "Label", "ClassCenterSample"); - OP_INOUT_CHECK(ctx->HasOutput("RemappedLabel"), - "Output", - "RemappedLabel", - "ClassCenterSample"); - OP_INOUT_CHECK(ctx->HasOutput("SampledLocalClassCenter"), - "Output", - "SampledLocalClassCenter", - "ClassCenterSample"); - - auto x_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(x_dims.size(), - 1, - platform::errors::InvalidArgument( - "Rank of Input(Label) should be equal to 1, " - "but the value given is %d.", - x_dims.size())); - - ctx->SetOutputDim("RemappedLabel", x_dims); - auto num_samples = ctx->Attrs().Get("num_samples"); - ctx->SetOutputDim("SampledLocalClassCenter", phi::make_ddim({num_samples})); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -144,6 +123,10 @@ class ClassCenterSampleOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(class_center_sample, + ClassCenterSampleInferShapeFunctor, + PD_INFER_META(phi::ClassCenterSampleInferMeta)); REGISTER_OP_WITHOUT_GRADIENT(class_center_sample, ops::ClassCenterSampleOp, - ops::ClassCenterSampleOpMaker); + ops::ClassCenterSampleOpMaker, + ClassCenterSampleInferShapeFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index b67498bcc1..d58acfd77e 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -455,6 +455,14 @@ func : celu backward : celu_grad +- api : class_center_sample + args : (Tensor label, int num_classes, int num_samples, int ring_id, int rank, int nranks, bool fix_seed, int seed) + output : Tensor(remapped_label), Tensor(sampled_local_class_center) + infer_meta : + func : ClassCenterSampleInferMeta + kernel : + func : class_center_sample + - api : clip args : (Tensor x, Scalar(float) min, Scalar(float) max) output : Tensor(out) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d6395c8a2e..7da162cd0b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -309,6 +309,35 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { out->set_dtype(x.dtype()); } +void ClassCenterSampleInferMeta(const MetaTensor& label, + int num_classes, + int num_samples, + int ring_id, + int rank, + int nranks, + bool fix_seed, + int seed, + MetaTensor* remapped_label, + MetaTensor* sampled_local_class_center) { + PADDLE_ENFORCE_EQ( + label.dims().size(), + 1, + errors::InvalidArgument("Rank of Input(Label) should be equal to 1, " + "but the value given is %d.", + label.dims().size())); + PADDLE_ENFORCE_NOT_NULL(remapped_label, + phi::errors::InvalidArgument( + "output of remapped label should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + sampled_local_class_center, + phi::errors::InvalidArgument( + "output of sampled local class center should not be null.")); + remapped_label->set_dims(label.dims()); + remapped_label->set_dtype(label.dtype()); + sampled_local_class_center->set_dims(phi::make_ddim({num_samples})); + sampled_local_class_center->set_dtype(label.dtype()); +} + void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out) { PADDLE_ENFORCE_GT( max_norm, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index f2bb43e952..d81c8ea7a4 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -67,6 +67,17 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); +void ClassCenterSampleInferMeta(const MetaTensor& label, + int num_classes, + int num_samples, + int ring_id, + int rank, + int nranks, + bool fix_seed, + int seed, + MetaTensor* remapped_label, + MetaTensor* sampled_local_class_center); + void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out); void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); diff --git a/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py b/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py index 492dae47f2..736d31b018 100644 --- a/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_class_center_sample_op.py @@ -56,10 +56,27 @@ def class_center_sample_numpy(label, classes_list, num_samples): return np.array(remapped_label), np.array(pos_class_center_per_device) +def python_api( + label, + num_classes=1, + num_samples=1, + ring_id=0, + rank=0, + nranks=0, + fix_seed=False, + seed=0, +): + return paddle.nn.functional.class_center_sample(label, + num_classes=num_classes, + num_samples=num_samples, + group=None) + + class TestClassCenterSampleOp(OpTest): def initParams(self): self.op_type = "class_center_sample" + self.python_api = python_api self.batch_size = 20 self.num_samples = 6 self.num_classes = 10 @@ -96,7 +113,8 @@ class TestClassCenterSampleOp(OpTest): } def test_check_output(self): - self.check_output(no_check_set=['SampledLocalClassCenter']) + self.check_output(no_check_set=['SampledLocalClassCenter'], + check_eager=True) class TestClassCenterSampleOpINT32(TestClassCenterSampleOp): diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index a7fad9a7c8..7f381a8846 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1958,7 +1958,11 @@ def class_center_sample(label, num_classes, num_samples, group=None): if (seed is None or seed == 0) and default_main_program().random_seed != 0: seed = default_main_program().random_seed - if in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_class_center_sample( + label, num_classes, num_samples, ring_id, rank, nranks, seed + is not None, seed if seed is not None else 0) + elif paddle.in_dynamic_mode(): remapped_label, sampled_class_center = _C_ops.class_center_sample( label, 'num_classes', num_classes, 'num_samples', num_samples, 'ring_id', ring_id, 'nranks', nranks, 'rank', rank, 'fix_seed', seed -- GitLab