diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt
index 5bbf4fbc616d9266c0c3ee3539d09ef19c2ca29d..1651dae2d0f045f74908a2095f6c59646329894b 100644
--- a/paddle/fluid/operators/prim_ops/CMakeLists.txt
+++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -27,7 +27,8 @@ set(PRIM_OP_SRCS
   log_p_op.cc
   select_p_op.cc
   eq_p_op.cc
-  pow_p_op.cc)
+  pow_p_op.cc
+  max_p_op.cc)
 
 cc_test(
   prim_op_test
diff --git a/paddle/fluid/operators/prim_ops/max_p_op.cc b/paddle/fluid/operators/prim_ops/max_p_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..786b57dbe823c892dbfdd9e6979a99b015111835
--- /dev/null
+++ b/paddle/fluid/operators/prim_ops/max_p_op.cc
@@ -0,0 +1,120 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+class MaxPrimOp : public framework::OperatorBase {
+ public:
+  MaxPrimOp(const std::string &type,
+            const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator max_p should not be executed directly"));
+  }
+};
+
+class MaxPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of max_p op.");
+    AddInput("Y", "(Tensor), The input tensor of max_p op.");
+    AddOutput("Z", "(Tensor), The output tensor of max_p op.");
+    AddComment(R"DOC(
+Autograd primitive max_p operator.
+)DOC");
+  }
+};
+
+class MaxPrimOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
+    framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
+    framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
+
+    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
+    framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
+    auto x_shape = x_var->GetShape();
+    auto y_shape = y_var->GetShape();
+    size_t x_rank = x_shape.size();
+    size_t y_rank = y_shape.size();
+    PADDLE_ENFORCE_EQ(x_rank,
+                      y_rank,
+                      platform::errors::InvalidArgument(
+                          "The dimensions of the two input tensors should be "
+                          "the same, but got %d and %d",
+                          x_rank,
+                          y_rank));
+    for (size_t i = 0; i < x_rank; ++i) {
+      PADDLE_ENFORCE_EQ(
+          x_shape[i],
+          y_shape[i],
+          platform::errors::InvalidArgument(
+              "The shapes of the two input tensors at dimension %d should "
+              "be the same, but got %d and %d",
+              i,
+              x_shape[i],
+              y_shape[i]));
+    }
+
+    PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
+  }
+};
+
+class MaxPrimOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = Input(ctx, "X")[0];
+    auto y_name = Input(ctx, "Y")[0];
+    auto z_name = Output(ctx, "Z")[0];
+    auto x_type = GetType(ctx, x_name);
+    auto y_type = GetType(ctx, y_name);
+    auto x_dtype = GetDataType(ctx, x_name);
+    auto y_dtype = GetDataType(ctx, y_name);
+    PADDLE_ENFORCE_EQ(x_type,
+                      y_type,
+                      platform::errors::InvalidArgument(
+                          "The types of the two input tensors should be the "
+                          "same, but got %d and %d",
+                          x_type,
+                          y_type));
+    PADDLE_ENFORCE_EQ(x_dtype,
+                      y_dtype,
+                      platform::errors::InvalidArgument(
+                          "The data types of the two input tensors should be "
+                          "the same, but got %d and %d",
+                          x_dtype,
+                          y_dtype));
+
+    SetType(ctx, z_name, x_type);
+    SetDataType(ctx, z_name, x_dtype);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(max_p,
+                  paddle::operators::MaxPrimOp,
+                  paddle::operators::MaxPrimOpMaker,
+                  paddle::operators::MaxPrimOpShapeInference,
+                  paddle::operators::MaxPrimOpVarTypeInference);
diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc
index 7f2f07cf1a66264a8c0d4cf7811c4180fadb5597..f3a74138abbd6919d2e06446e25a74f2f555cbab 100644
--- a/paddle/fluid/operators/prim_ops/prim_op_test.cc
+++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc
@@ -38,6 +38,7 @@ USE_OP_ITSELF(log_p);
 USE_OP_ITSELF(select_p);
 USE_OP_ITSELF(eq_p);
 USE_OP_ITSELF(pow_p);
+USE_OP_ITSELF(max_p);
 
 namespace paddle {
 namespace framework {
@@ -687,5 +688,27 @@ TEST(PrimOp, pow_p) {
   ASSERT_EQ(shapes[2], 5L);
 }
 
+TEST(PrimOp, max_p) {
+  ProgramDesc program;
+  auto *block = program.MutableBlock(0);
+  std::vector<int64_t> shape{2, 3, 4};
+
+  std::string x = "x";
+  std::string y = "y";
+  std::string z = "z";
+
+  NewVar(block, x, shape);
+  NewVar(block, y, shape);
+
+  AppendOp(block, "max_p", {{"X", {x}}, {"Y", {y}}}, {{"Z", {z}}}, {});
+  ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType_Type_FP32);
+  auto shapes = block->Var("z")->GetShape();
+  ASSERT_EQ(shapes.size(), 3UL);
+  ASSERT_EQ(shapes[0], 2L);
+  ASSERT_EQ(shapes[1], 3L);
+  ASSERT_EQ(shapes[2], 4L);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
index c09e5bf86480bd8e823ac3d68773e78b0af62a86..2c23da54970324dff50ba92927c72000871e2c90 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
@@ -982,5 +982,36 @@ class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose):
         ]
 
 
+class TestMaxPJVPAndTranspose(TestAddPJVPAndTranspose):
+
+    def init_data(self):
+        # Set prim op
+        self.op_type = 'max_p'
+        X = paddle.static.data(name='X', shape=[5, 6], dtype='float32')
+        Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32')
+        self.prim_input = {'X': X, 'Y': Y}
+        self.prim_output = {
+            'Z':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.prim_attrs = {}
+
+        # Set JVP
+        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='float32')
+        Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='float32')
+        self.jvp_args = (X_DOT, Y_DOT)
+        self.jvp_out_shape_map = {0: self.prim_output['Z']}
+
+        self.all_ops = [
+            # prim op:
+            'max_p',
+            # jvp op:
+            'fill_constant_p',
+            'eq_p',
+            'select_p',
+            # transpose op:
+        ]
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
index 87f949034152e91adf0db271c7179e87babf11b9..ce9c64fbbed78250a29c3da566d5e06f71c128b4 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -539,5 +539,25 @@ class TestPowOrig2Prim(TestElementWiseAddOrig2Prim):
         self.out_map = {0: self.output['Out']}
 
 
+class TestMaxOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'elementwise_max'
+        X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
+        Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
+
+        self.input = {'X': X, 'Y': Y}
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, Y)
+        self.all_ops = ['elementwise_max', 'max_p']
+        # { prim_op_output_index: orig_op_output_var }
+        self.out_map = {0: self.output['Out']}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
index 8d23ddad1d2b92889315b0815d59eb16df48fd75..53120ce742aff43b084b7568f4d659459d98a581 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -554,5 +554,24 @@ class TestPowPPrim2Orig(TestAddPPrim2Orig):
         self.out_map = {self.output['Z']: 0}
 
 
+class TestMaxPPrim2Orig(TestAddPPrim2Orig):
+
+    def init_data(self):
+        self.op_type = 'max_p'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
+
+        self.input = {'X': X, 'Y': Y}
+        self.output = {
+            'Z':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.prim2orig_args = (X, Y)
+        self.all_ops = ['max_p', 'elementwise_max']
+        self.out_map = {self.output['Z']: 0}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index 04610ce2c7dc78dda05ddd0973a738af07278eb1..5bd21e419045929f2496526892501ba0dc31adc5 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -151,7 +151,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
      (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
     ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
 ))
-# paddle.where, paddle.pow has no double grad definition,
+# paddle.where, paddle.pow and paddle.maximum have no double grad definition,
 # can not compute forward grad use double trick
 class TestForwardGrad(unittest.TestCase):
 
@@ -273,6 +273,11 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
     # pow_p and pow has diff when compute z_dot of 0^0
     ('pow', paddle.pow, (np.array([1, 2, 3]), np.array([0, 2, 7])), None,
      'float32'),
+    # To keep max_p consistent with paddle.maximum, ensure x.grad = 0 and y.grad = 1 when x == y.
+    ('max', paddle.maximum, (
+        np.array([1, 2, 3]),
+        np.array([2, 2, 2]),
+    ), None, 'float32'),
 ))
 class TestGrad(unittest.TestCase):
 
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py
index 79e9326a8cc343ef3cd651c9f0c573e15013cfa5..f1396ce69f9e05269b58960500e96785d432fa38 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py
@@ -99,6 +99,7 @@ paddle.enable_static()
      (randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
     ('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
     ('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
+    ('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
 ))
 class TestPrimops(unittest.TestCase):
 
diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py
index bd48a86fe496f453cc252001af77a5f2d97a7d45..d29647f9404bf107054aba079755525397dd8d07 100644
--- a/python/paddle/incubate/autograd/primops.py
+++ b/python/paddle/incubate/autograd/primops.py
@@ -350,3 +350,8 @@ def eq(x, y, out=None):
 @REGISTER_FN('pow_p', 'X', 'Y', 'Z')
 def pow(x, y, out=None):
     return _simple_binop(LayerHelper('pow_p', **locals()))
+
+
+@REGISTER_FN('max_p', 'X', 'Y', 'Z')
+def max(x, y, out=None):
+    return _simple_binop(LayerHelper('max_p', **locals()))
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 3795bffae0d63850be1f88f1483a765c608bf7a2..bfcfcfb9a4fae645ea88b8f4f30bbe61c739e7a9 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -19,7 +19,7 @@ from . import primops
 from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
                       matmul, mul, neg, reduce, reshape, scatter_add, set_value,
                       sin, slice_assign, slice_select, split, sqrt, sub, tanh,
-                      transpose, log, select, eq)
+                      transpose, log, select, eq, max)
 from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
                       REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
                       lookup_orig2prim, lookup_prim2orig, lookup_transpose,
@@ -317,6 +317,14 @@ def elementwise_pow_orig2prim(op, x, y):
     return z
 
 
+@REGISTER_ORIG2PRIM('elementwise_max')
+def elementwise_max_orig2prim(op, x, y):
+    if x.shape != y.shape:
+        y = broadcast(y, shape=x.shape)
+
+    return primops.max(x, y)
+
+
 ## Register prim2orig lower rules
@@ -466,6 +474,11 @@ def pow_prim2orig(op, x, y):
     return paddle.pow(x, y)
 
 
+@REGISTER_PRIM2ORIG('max_p')
+def max_prim2orig(op, x, y):
+    return paddle.maximum(x, y)
+
+
 ## Register linearize rules
 @REGISTER_JVP('add_p')
 def add_jvp(op, x_dot, y_dot):
@@ -737,6 +750,26 @@ def pow_jvp(op, x_dot, y_dot):
     return z_dot
 
 
+@REGISTER_JVP('max_p')
+def max_jvp(op, x_dot, y_dot):
+    if x_dot is None and y_dot is None:
+        return None
+
+    x, y = op_position_inputs(op)
+    z = op_position_output(op)
+    z_zeros = fill_const(value=0.0, shape=z.shape, dtype=z.dtype)
+
+    # To make the grad of max_p consistent with paddle.maximum when x == y,
+    # we simply take z_dot = y_dot at the tie (i.e. wherever eq(y, z) holds),
+    # instead of using balance_eq like JAX does.
+    if y_dot is None:
+        return select(eq(y, z), z_zeros, x_dot)
+    elif x_dot is None:
+        return select(eq(y, z), y_dot, z_zeros)
+    else:
+        return select(eq(y, z), y_dot, x_dot)
+
+
 ## Register transpose rules
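
Note: the `max_jvp` rule above resolves ties by forwarding `y_dot` wherever `y == z`. A minimal NumPy sketch of that `select(eq(y, z), y_dot, x_dot)` logic (illustrative only, not Paddle code; the array values are made up):

```python
import numpy as np

# Model of the max_p JVP: z = max(x, y), z_dot = where(y == z, y_dot, x_dot).
# At a tie (x == y), y's tangent is forwarded, matching the rule in primrules.py.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 2.0, 2.0])
x_dot = np.array([10.0, 10.0, 10.0])
y_dot = np.array([20.0, 20.0, 20.0])

z = np.maximum(x, y)
z_dot = np.where(y == z, y_dot, x_dot)
print(z_dot)  # [20. 20. 10.] -- element 1 is the tie and takes y_dot
```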
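The `('max', paddle.maximum, ...)` case added to `test_primapi.py` exercises the same convention end to end. A hedged eager-mode sketch of what that convention means, assuming `paddle.maximum` routes the gradient to `y` at ties as the in-diff comment states (this snippet is not part of the patch):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.array([1.0, 2.0, 3.0]), stop_gradient=False)
y = paddle.to_tensor(np.array([2.0, 2.0, 2.0]), stop_gradient=False)
z = paddle.maximum(x, y)

# With the tie-breaking convention described above, the gradient at x == y goes to y.
x_grad, y_grad = paddle.grad(z, [x, y])
print(x_grad.numpy())  # expected [0. 0. 1.]
print(y_grad.numpy())  # expected [1. 1. 0.]
```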