From 1711407d40608d5642a98ffa62bcc59b373b7152 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Fri, 16 Sep 2022 14:21:29 +0800 Subject: [PATCH] support pow with scalar input, square, cast, var, size operators for deepxde (#46024) * add reduce_mean,reduce_sum primitive ops * add ne_p gt_p primitive operators * add ge_p abs_p primitive oparators * add cast primitive operators * add pow,square prim2oirg rules * add elementwise_div orig2prim rule --- .../fluid/operators/prim_ops/CMakeLists.txt | 3 +- paddle/fluid/operators/prim_ops/cast_p_op.cc | 78 ++++++++++++++++ .../autograd/test_jvp_and_transpose.py | 36 +++++++ .../unittests/autograd/test_orig2prim.py | 93 +++++++++++++++++++ .../unittests/autograd/test_prim2orig.py | 20 ++++ .../tests/unittests/autograd/test_primapi.py | 18 +++- .../tests/unittests/autograd/test_primops.py | 3 + python/paddle/incubate/autograd/primops.py | 12 +++ python/paddle/incubate/autograd/primreg.py | 2 +- python/paddle/incubate/autograd/primrules.py | 49 +++++++++- 10 files changed, 310 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/prim_ops/cast_p_op.cc diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt index 9d24cf89af..30e162a4dd 100644 --- a/paddle/fluid/operators/prim_ops/CMakeLists.txt +++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt @@ -36,7 +36,8 @@ set(PRIM_OP_SRCS pow_p_op.cc max_p_op.cc erf_p_op.cc - abs_p_op.cc) + abs_p_op.cc + cast_p_op.cc) cc_test( prim_op_test diff --git a/paddle/fluid/operators/prim_ops/cast_p_op.cc b/paddle/fluid/operators/prim_ops/cast_p_op.cc new file mode 100644 index 0000000000..5c8b9ab45c --- /dev/null +++ b/paddle/fluid/operators/prim_ops/cast_p_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class CastPrimOp : public framework::OperatorBase { + public: + CastPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator cast_p should not be excuted directly")); + } +}; + +class CastPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of cast_p op."); + AddOutput("Y", "(Tensor), The output tensor of cast_p op."); + AddAttr("dtype", "output data type"); + AddComment(R"DOC(Autograd primitive cast_p operator.)DOC"); + } +}; + +class CastPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class CastPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto out_type = static_cast( + PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); + ctx->SetOutputDataType("Y", out_type); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(cast_p, + paddle::operators::CastPrimOp, + paddle::operators::CastPrimOpMaker, + paddle::operators::CastPrimOpShapeInference, + paddle::operators::CastPrimOpVarTypeInference); diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py index 76698a7a8b..c8e1b39652 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -433,6 +433,42 @@ class TestAbsPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestCastPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'cast_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = { + 'X': X, + } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'dtype': paddle.float64} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: True + Y_BAR = paddle.static.data(name='Y_BAR', shape=[5, 6], dtype='float') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X} + + self.all_ops = [ + # prim op: + 'cast_p', + # jvp op: + 'cast_p', + # transpose op: + 'cast_p' + ] + + class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 92a50d8bb1..e1d5ee11a1 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -110,6 +110,26 @@ class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestElementWiseDivOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'elementwise_div' + X = paddle.static.data(name='X', shape=[8, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, Y) + self.all_ops = ['elementwise_div', 'div_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): @@ -786,5 +806,78 @@ class TestReduceMeanOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestSizeOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'size' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'Input': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference( + dtype=paddle.int64) + } + self.attrs = {} + self.orig2prim_args = (X, ) + self.all_ops = ['size', 'fill_constant_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestCastOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'cast' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'in_dtype': X.dtype, 'out_dtype': paddle.float64} + self.orig2prim_args = (X, ) + self.all_ops = ['cast', 'cast_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestPowScalarOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'pow' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'factor': 2.} + self.orig2prim_args = (None, X) + self.all_ops = ['pow', 'pow_p', 'fill_constant_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'square' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + self.orig2prim_args = (X, ) + self.all_ops = ['square', 'pow_p', 'fill_constant_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py index c173cc4790..a89b91bdd2 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py @@ -670,5 +670,25 @@ class TestMaxPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Z']: 0} +class TestCastPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'cast_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = { + 'X': X, + } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'dtype': paddle.int64} + + self.prim2orig_args = (X, ) + self.all_ops = ['cast_p', 'cast'] + self.out_map = {self.output['Y']: 0} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index d010e69e75..bdc54563fc 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -257,6 +257,8 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'), ('multiply', paddle.multiply, (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), + ('div', paddle.divide, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), ('add', paddle.add, (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), ('input_not_sequence', paddle.tanh, @@ -300,7 +302,21 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) (np.random.rand(200, 345), ), None, 'float32'), ('abs', paddle.abs, (np.random.uniform(-10, 10, (200, 345)), ), None, 'float32'), - )) + ('cast_float', lambda x: paddle.cast(x, paddle.float64), + (np.random.rand(10, 20), ), None, 'float32'), + ('cast_int', lambda x: paddle.cast(x, paddle.int32), + (np.random.rand(10, 20), ), None, 'float32'), + ('square', paddle.square, (np.random.rand(100), ), None, 'float32'), + ('pow_scalar', lambda x: paddle.pow(x, 2), + (np.random.rand(20, 30), ), None, 'float32'), + ('var', paddle.var, (np.random.rand(200, 324), ), None, 'float32'), + ('var_with_axis', lambda x: paddle.var(x, axis=1), + (np.random.rand(10, 20, 30), ), None, 'float32'), + ('var_without_unbiased', + lambda x: paddle.var(x, axis=1, unbiased=False), + (np.random.rand(10, 20, 30), ), None, 'float32'), + ('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True), + (np.random.rand(10, 20, 30), ), None, 'float32'))) class TestGrad(unittest.TestCase): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py index ba6f094e68..35291432f6 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py @@ -44,6 +44,9 @@ paddle.enable_static() ('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'), ('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'), ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'), + ('cast', primops.cast, randn(2, 3), { + 'dtype': paddle.int64 + }, (2, 3), 'int64'), ('reshape', primops.reshape, randn(2, 3), { 'shape': (3, 2) }, (3, 2), 'float64'), diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py index dde3fb492c..636dc89220 100644 --- a/python/paddle/incubate/autograd/primops.py +++ b/python/paddle/incubate/autograd/primops.py @@ -382,3 +382,15 @@ def max(x, y, out=None): @REGISTER_FN('erf_p', 'X', 'Y') def erf(x, out=None): return _simple_unop(LayerHelper('erf_p', **locals())) + + +@REGISTER_FN('cast_p', 'X', 'Y') +def cast(x, dtype, out=None): + helper = LayerHelper('cast_p', **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype) + helper.append_op(type=helper.layer_type, + inputs={'X': x}, + outputs={'Y': out}, + attrs={'dtype': dtype}) + return out diff --git a/python/paddle/incubate/autograd/primreg.py b/python/paddle/incubate/autograd/primreg.py index 4721500b2b..7972409d93 100644 --- a/python/paddle/incubate/autograd/primreg.py +++ b/python/paddle/incubate/autograd/primreg.py @@ -80,7 +80,7 @@ def op_position_inputs(op): """ args = _primop_position_argnames.lookup(op.type) - assert args is not None, 'args should not be None in op_position_inputs().' + assert args is not None, f'args of {op.type} should not be None in op_position_inputs().' *input_names, _ = args inputs = [] diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 954bdf0cb1..4625cfd362 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -158,6 +158,13 @@ def elementwise_mul_orig2prim(op, x, y): return z +@REGISTER_ORIG2PRIM('elementwise_div') +def elementwise_div_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + return primops.div(x, y) + + @REGISTER_ORIG2PRIM('tanh') def tanh_orig2prim(op, x): return tanh(x) @@ -322,6 +329,11 @@ def p_norm_orig2prim(op, x): raise RuntimeError('Only support lower l2/l1 norm currently') +@REGISTER_ORIG2PRIM('cast') +def cast_orig2prim(op, x): + return primops.cast(x, paddle.dtype(op.attr('out_dtype'))) + + # TODO: support broadcast @REGISTER_ORIG2PRIM('where') def select_orig2prim(op, condition, x, y): @@ -356,15 +368,27 @@ def ge_orig2prim(op, x, y): return ge(x, y) +# paddle.pow API use "elementwise_pow" operator when y is a Tensor. @REGISTER_ORIG2PRIM('elementwise_pow') def elementwise_pow_orig2prim(op, x, y): if x.shape != y.shape: y = broadcast(y, shape=x.shape) - z = primops.pow(x, y) return z +# paddle.pow API use "pow" operator when y is a scalar. +@REGISTER_ORIG2PRIM('pow') +def pow_orig2prim(op, x, y): + # x is factorTensor defined in paddle phi op. Currently it is None. + return primops.pow(y, fill_const(op.attr('factor'), y.shape, y.dtype)) + + +@REGISTER_ORIG2PRIM('square') +def square_orig2prim(op, x): + return primops.pow(x, fill_const(2., x.shape, x.dtype)) + + @REGISTER_ORIG2PRIM('elementwise_max') def elementwise_max_orig2prim(op, x, y): if x.shape != y.shape: @@ -415,6 +439,12 @@ def reduce_mean_orig2prim(op, x): return div(sum, norm) +@REGISTER_ORIG2PRIM('size') +def size_orig2prim(op, x): + return fill_const(functools.reduce(operator.mul, x.shape), (1, ), + paddle.int64) + + ## Register prim2orig lower rules @REGISTER_PRIM2ORIG('add_p') def add_prim2orig(op, x, y): @@ -592,6 +622,11 @@ def max_prim2orig(op, x, y): return paddle.maximum(x, y) +@REGISTER_PRIM2ORIG('cast_p') +def cast_prim2orig(op, x): + return paddle.cast(x, paddle.dtype(op.attr('dtype'))) + + ## Register linearize rules @REGISTER_JVP('add_p') def add_jvp(op, x_dot, y_dot): @@ -928,6 +963,12 @@ def max_jvp(op, x_dot, y_dot): return select(eq(y, z), y_dot, x_dot) +@REGISTER_JVP('cast_p') +def cast_jvp(op, x_dot): + y = op_position_output(op) + return primops.cast(x_dot, y.dtype) + + ## Register transpose rules @@ -1132,3 +1173,9 @@ def select_transpose(op, check_dot, z_bar): y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None return cond_bar, x_bar, y_bar + + +@REGISTER_TRANSPOSE('cast_p') +def cast_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + return primops.cast(y_bar, x.dtype) -- GitLab