Unverified commit b1e33bea, authored by duanboqiang and committed by GitHub

[phi] migration of class center sample infermeta (#45025)

* add class center sample infershape

* add yaml

* modify unittest

* modify unittest

* remove comment
Parent 9e74211f
@@ -12,8 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 namespace paddle {
 namespace operators {
@@ -21,30 +24,6 @@ namespace operators {
 class ClassCenterSampleOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(
-        ctx->HasInput("Label"), "Input", "Label", "ClassCenterSample");
-    OP_INOUT_CHECK(ctx->HasOutput("RemappedLabel"),
-                   "Output",
-                   "RemappedLabel",
-                   "ClassCenterSample");
-    OP_INOUT_CHECK(ctx->HasOutput("SampledLocalClassCenter"),
-                   "Output",
-                   "SampledLocalClassCenter",
-                   "ClassCenterSample");
-    auto x_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(),
-                      1,
-                      platform::errors::InvalidArgument(
-                          "Rank of Input(Label) should be equal to 1, "
-                          "but the value given is %d.",
-                          x_dims.size()));
-    ctx->SetOutputDim("RemappedLabel", x_dims);
-    auto num_samples = ctx->Attrs().Get<int>("num_samples");
-    ctx->SetOutputDim("SampledLocalClassCenter", phi::make_ddim({num_samples}));
-  }
  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -144,6 +123,10 @@ class ClassCenterSampleOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(class_center_sample,
+                            ClassCenterSampleInferShapeFunctor,
+                            PD_INFER_META(phi::ClassCenterSampleInferMeta));
 REGISTER_OP_WITHOUT_GRADIENT(class_center_sample,
                              ops::ClassCenterSampleOp,
-                             ops::ClassCenterSampleOpMaker);
+                             ops::ClassCenterSampleOpMaker,
+                             ClassCenterSampleInferShapeFunctor);
@@ -455,6 +455,14 @@
   func : celu
   backward : celu_grad
+- api : class_center_sample
+  args : (Tensor label, int num_classes, int num_samples, int ring_id, int rank, int nranks, bool fix_seed, int seed)
+  output : Tensor(remapped_label), Tensor(sampled_local_class_center)
+  infer_meta :
+    func : ClassCenterSampleInferMeta
+  kernel :
+    func : class_center_sample
 - api : clip
   args : (Tensor x, Scalar(float) min, Scalar(float) max)
   output : Tensor(out)
......
@@ -309,6 +309,35 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
   out->set_dtype(x.dtype());
 }
+void ClassCenterSampleInferMeta(const MetaTensor& label,
+                                int num_classes,
+                                int num_samples,
+                                int ring_id,
+                                int rank,
+                                int nranks,
+                                bool fix_seed,
+                                int seed,
+                                MetaTensor* remapped_label,
+                                MetaTensor* sampled_local_class_center) {
+  PADDLE_ENFORCE_EQ(
+      label.dims().size(),
+      1,
+      errors::InvalidArgument("Rank of Input(Label) should be equal to 1, "
+                              "but the value given is %d.",
+                              label.dims().size()));
+  PADDLE_ENFORCE_NOT_NULL(remapped_label,
+                          phi::errors::InvalidArgument(
+                              "output of remapped label should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(
+      sampled_local_class_center,
+      phi::errors::InvalidArgument(
+          "output of sampled local class center should not be null."));
+  remapped_label->set_dims(label.dims());
+  remapped_label->set_dtype(label.dtype());
+  sampled_local_class_center->set_dims(phi::make_ddim({num_samples}));
+  sampled_local_class_center->set_dtype(label.dtype());
+}
 void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out) {
   PADDLE_ENFORCE_GT(
       max_norm,
......
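For reference, a minimal dygraph sketch of the compile-time shape contract that ClassCenterSampleInferMeta encodes: the label input must be rank-1, remapped_label keeps label's shape and dtype, and sampled_local_class_center gets a static dim of [num_samples] (its runtime length depends on which classes get sampled, which is why the unit test below skips checking that output). The tensor sizes here are illustrative only.

    import paddle

    # Rank-1 integer labels; the InferMeta enforces label.dims().size() == 1.
    label = paddle.randint(low=0, high=10, shape=[20], dtype='int64')

    remapped_label, sampled_class_center = paddle.nn.functional.class_center_sample(
        label, num_classes=10, num_samples=6)

    print(remapped_label.shape)   # [20]  -- same shape as label
    print(remapped_label.dtype)   # paddle.int64  -- same dtype as label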
@@ -67,6 +67,17 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
 void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
+void ClassCenterSampleInferMeta(const MetaTensor& label,
+                                int num_classes,
+                                int num_samples,
+                                int ring_id,
+                                int rank,
+                                int nranks,
+                                bool fix_seed,
+                                int seed,
+                                MetaTensor* remapped_label,
+                                MetaTensor* sampled_local_class_center);
 void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out);
 void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
......
@@ -56,10 +56,27 @@ def class_center_sample_numpy(label, classes_list, num_samples):
     return np.array(remapped_label), np.array(pos_class_center_per_device)
+def python_api(
+    label,
+    num_classes=1,
+    num_samples=1,
+    ring_id=0,
+    rank=0,
+    nranks=0,
+    fix_seed=False,
+    seed=0,
+):
+    return paddle.nn.functional.class_center_sample(label,
+                                                    num_classes=num_classes,
+                                                    num_samples=num_samples,
+                                                    group=None)
 class TestClassCenterSampleOp(OpTest):
     def initParams(self):
         self.op_type = "class_center_sample"
+        self.python_api = python_api
         self.batch_size = 20
         self.num_samples = 6
         self.num_classes = 10
@@ -96,7 +113,8 @@ class TestClassCenterSampleOp(OpTest):
         }
     def test_check_output(self):
-        self.check_output(no_check_set=['SampledLocalClassCenter'])
+        self.check_output(no_check_set=['SampledLocalClassCenter'],
+                          check_eager=True)
 class TestClassCenterSampleOpINT32(TestClassCenterSampleOp):
......
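As a hedged sketch of how the python_api shim above is exercised: it accepts the op's full attribute list for signature compatibility, but only num_classes and num_samples reach paddle.nn.functional.class_center_sample (the distributed and seed attributes are ignored on this single-card path). The values below mirror the unit test's settings and are illustrative only; python_api is assumed to be in scope.

    import numpy as np
    import paddle

    # batch_size=20, num_classes=10, num_samples=6, as in TestClassCenterSampleOp.
    label = paddle.to_tensor(np.random.randint(0, 10, (20,)).astype('int64'))

    remapped_label, sampled_center = python_api(
        label, num_classes=10, num_samples=6, ring_id=0, rank=0, nranks=1,
        fix_seed=True, seed=2022)

    # Only RemappedLabel is compared exactly; the sampled centers are random,
    # hence no_check_set=['SampledLocalClassCenter'] in test_check_output.
    print(remapped_label.numpy())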
@@ -1958,7 +1958,11 @@ def class_center_sample(label, num_classes, num_samples, group=None):
     if (seed is None or seed == 0) and default_main_program().random_seed != 0:
         seed = default_main_program().random_seed
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_class_center_sample(
+            label, num_classes, num_samples, ring_id, rank, nranks, seed
+            is not None, seed if seed is not None else 0)
+    elif paddle.in_dynamic_mode():
         remapped_label, sampled_class_center = _C_ops.class_center_sample(
             label, 'num_classes', num_classes, 'num_samples', num_samples,
             'ring_id', ring_id, 'nranks', nranks, 'rank', rank, 'fix_seed', seed
......
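Finally, a short usage sketch of the public API whose dygraph path now dispatches to the new final_state op; the class and sample counts are illustrative only.

    import paddle

    num_classes, num_samples = 20, 6
    label = paddle.randint(low=0, high=num_classes, shape=[10], dtype='int64')

    # In dygraph mode this call now takes the in_dygraph_mode() branch above.
    remapped_label, sampled_class_center = paddle.nn.functional.class_center_sample(
        label, num_classes, num_samples)

    print(remapped_label)
    print(sampled_class_center)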