未验证 提交 197f4048 编写于 作者: S Sing_chan 提交者: GitHub

【autograd】add max_p primitive operator for new autograd (#45178)

* add max_p without test

* add test of max_p

* make max_p consistent with paddle.maximum
上级 7c1e7e46
...@@ -27,7 +27,8 @@ set(PRIM_OP_SRCS ...@@ -27,7 +27,8 @@ set(PRIM_OP_SRCS
log_p_op.cc log_p_op.cc
select_p_op.cc select_p_op.cc
eq_p_op.cc eq_p_op.cc
pow_p_op.cc) pow_p_op.cc
max_p_op.cc)
cc_test( cc_test(
prim_op_test prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
// Operator class for the autograd primitive `max_p`.
//
// The constructor only forwards its arguments to framework::OperatorBase.
// RunImpl unconditionally throws, so an instance of this op cannot be run.
class MaxPrimOp : public framework::OperatorBase {
 public:
  // Forwards the op type, input/output name maps and attributes to the base.
  MaxPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

  // Always raises an Unimplemented error; neither `scope` nor `dev_place`
  // is read.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator max_p should not be executed directly"));
  }
};
// Proto maker for `max_p`: declares two inputs ("X", "Y"), one output ("Z")
// and the operator comment.
class MaxPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Input slots.
    AddInput("X", "(Tensor), The input tensor of max_p op.");
    AddInput("Y", "(Tensor), The input tensor of max_p op.");

    // Output slot.
    AddOutput("Z", "(Tensor), The output tensor of max_p op.");

    // The raw string is kept at column 0 so its content is unchanged.
    AddComment(R"DOC(
Autograd primitive max_p operator.
)DOC");
  }
};
// Shape inference for `max_p`.
//
// Reads the shapes of the first "X" and "Y" variables, requires them to have
// the same rank and the same extent in every dimension, and then copies the
// shape of "X" onto the first "Z" variable.
class MaxPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    // Take the first variable bound to each slot.
    framework::InferShapeVarPtr lhs_ptr = ctx->GetInputVarPtrs("X")[0];
    framework::InferShapeVarPtr rhs_ptr = ctx->GetInputVarPtrs("Y")[0];
    framework::InferShapeVarPtr out_ptr = ctx->GetOutputVarPtrs("Z")[0];

    framework::VarDesc *lhs_desc = PADDLE_GET(framework::VarDesc *, lhs_ptr);
    framework::VarDesc *rhs_desc = PADDLE_GET(framework::VarDesc *, rhs_ptr);

    auto lhs_shape = lhs_desc->GetShape();
    auto rhs_shape = rhs_desc->GetShape();
    size_t lhs_rank = lhs_shape.size();
    size_t rhs_rank = rhs_shape.size();

    // Ranks must agree before comparing individual dimensions.
    PADDLE_ENFORCE_EQ(lhs_rank,
                      rhs_rank,
                      platform::errors::InvalidArgument(
                          "The dimensions of two input tensor should be same, "
                          "but get %d and %d",
                          lhs_rank,
                          rhs_rank));

    // Every dimension must match exactly.
    for (size_t axis = 0; axis < lhs_rank; ++axis) {
      PADDLE_ENFORCE_EQ(
          lhs_shape[axis],
          rhs_shape[axis],
          platform::errors::InvalidArgument(
              "The shape of two input tensor at dimension %d should be same, "
              "but get %d and %d",
              axis,
              lhs_shape[axis],
              rhs_shape[axis]));
    }

    // The output takes the (validated) shape of "X".
    PADDLE_GET(framework::VarDesc *, out_ptr)->SetShape(lhs_shape);
  }
};
// Variable-type inference for `max_p`.
//
// Requires the first "X" and "Y" variables to share both variable type and
// data type, then assigns those of "X" to the first "Z" variable.
class MaxPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    // Names of the first variable bound to each slot.
    auto lhs_name = Input(ctx, "X")[0];
    auto rhs_name = Input(ctx, "Y")[0];
    auto out_name = Output(ctx, "Z")[0];

    auto lhs_var_type = GetType(ctx, lhs_name);
    auto rhs_var_type = GetType(ctx, rhs_name);
    auto lhs_data_type = GetDataType(ctx, lhs_name);
    auto rhs_data_type = GetDataType(ctx, rhs_name);

    // Variable types must match.
    PADDLE_ENFORCE_EQ(lhs_var_type,
                      rhs_var_type,
                      platform::errors::InvalidArgument(
                          "The type of two input tensor should be same, "
                          "but get %d and %d",
                          lhs_var_type,
                          rhs_var_type));

    // Element data types must match.
    PADDLE_ENFORCE_EQ(lhs_data_type,
                      rhs_data_type,
                      platform::errors::InvalidArgument(
                          "The datatype of two input tensor should be same, "
                          "but get %d and %d",
                          lhs_data_type,
                          rhs_data_type));

    // Propagate the type information of "X" to the output.
    SetType(ctx, out_name, lhs_var_type);
    SetDataType(ctx, out_name, lhs_data_type);
  }
};
} // namespace operators
} // namespace paddle
// Register the op under the name `max_p`, passing its operator class, proto
// maker, shape-inference functor and variable-type-inference functor.
REGISTER_OPERATOR(max_p,
paddle::operators::MaxPrimOp,
paddle::operators::MaxPrimOpMaker,
paddle::operators::MaxPrimOpShapeInference,
paddle::operators::MaxPrimOpVarTypeInference);
...@@ -38,6 +38,7 @@ USE_OP_ITSELF(log_p); ...@@ -38,6 +38,7 @@ USE_OP_ITSELF(log_p);
USE_OP_ITSELF(select_p); USE_OP_ITSELF(select_p);
USE_OP_ITSELF(eq_p); USE_OP_ITSELF(eq_p);
USE_OP_ITSELF(pow_p); USE_OP_ITSELF(pow_p);
USE_OP_ITSELF(max_p);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -687,5 +688,27 @@ TEST(PrimOp, pow_p) { ...@@ -687,5 +688,27 @@ TEST(PrimOp, pow_p) {
ASSERT_EQ(shapes[2], 5L); ASSERT_EQ(shapes[2], 5L);
} }
// Appends a single `max_p` op over two inputs of shape {2, 3, 4} and checks
// the type, data type and shape recorded on the output variable "z".
TEST(PrimOp, max_p) {
  ProgramDesc program;
  auto *block = program.MutableBlock(0);

  std::vector<int64_t> input_shape{2, 3, 4};
  std::string lhs_name = "x";
  std::string rhs_name = "y";
  std::string out_name = "z";

  NewVar(block, lhs_name, input_shape);
  NewVar(block, rhs_name, input_shape);
  AppendOp(block,
           "max_p",
           {{"X", {lhs_name}}, {"Y", {rhs_name}}},
           {{"Z", {out_name}}},
           {});

  // Look the output descriptor up once and reuse it for every check.
  auto *out_desc = block->Var("z");
  ASSERT_EQ(out_desc->GetType(), proto::VarType::LOD_TENSOR);
  ASSERT_EQ(out_desc->GetDataType(), proto::VarType_Type_FP32);

  auto out_shape = out_desc->GetShape();
  ASSERT_EQ(out_shape.size(), 3UL);
  ASSERT_EQ(out_shape[0], 2L);
  ASSERT_EQ(out_shape[1], 3L);
  ASSERT_EQ(out_shape[2], 4L);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -982,5 +982,36 @@ class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose): ...@@ -982,5 +982,36 @@ class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose):
] ]
class TestMaxPJVPAndTranspose(TestAddPJVPAndTranspose):
    """Supplies ``max_p`` data to the inherited JVP/transpose test case."""

    def init_data(self):
        """Populate the attributes the base test case reads."""
        # Primitive op under test, its operands and its output variable.
        self.op_type = 'max_p'
        lhs = paddle.static.data(name='X', shape=[5, 6], dtype='float32')
        rhs = paddle.static.data(name='Y', shape=[5, 6], dtype='float32')
        self.prim_input = {'X': lhs, 'Y': rhs}
        out_var = self.layer_help.create_variable_for_type_inference(
            dtype=lhs.dtype)
        self.prim_output = {'Z': out_var}
        self.prim_attrs = {}

        # Tangent inputs handed to the JVP rule.
        lhs_dot = paddle.static.data(name='X_DOT',
                                     shape=[5, 6],
                                     dtype='float32')
        rhs_dot = paddle.static.data(name='Y_DOT',
                                     shape=[5, 6],
                                     dtype='float32')
        self.jvp_args = (lhs_dot, rhs_dot)
        self.jvp_out_shape_map = {0: self.prim_output['Z']}

        # Op types expected in the program; no transpose ops are listed.
        self.all_ops = [
            # prim op:
            'max_p',
            # jvp op:
            'fill_constant_p',
            'eq_p',
            'select_p',
            # transpose op:
        ]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -539,5 +539,25 @@ class TestPowOrig2Prim(TestElementWiseAddOrig2Prim): ...@@ -539,5 +539,25 @@ class TestPowOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']} self.out_map = {0: self.output['Out']}
class TestMaxOrig2Prim(TestElementWiseAddOrig2Prim):
    """Supplies ``elementwise_max`` data to the inherited orig2prim test."""

    def init_data(self):
        """Populate the attributes the base test case reads."""
        self.op_type = 'elementwise_max'

        # Operands of the original op and its output variable.
        lhs = paddle.static.data(name='X', shape=[5, 8], dtype='float')
        rhs = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
        self.input = {'X': lhs, 'Y': rhs}
        out_var = self.layer_help.create_variable_for_type_inference(
            dtype=lhs.dtype)
        self.output = {'Out': out_var}
        self.attrs = {}

        # Arguments for the lowering rule and the op types expected afterwards.
        self.orig2prim_args = (lhs, rhs)
        self.all_ops = ['elementwise_max', 'max_p']
        # { prim_op_output_index: orig_op_output_var }
        self.out_map = {0: self.output['Out']}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -554,5 +554,24 @@ class TestPowPPrim2Orig(TestAddPPrim2Orig): ...@@ -554,5 +554,24 @@ class TestPowPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Z']: 0} self.out_map = {self.output['Z']: 0}
class TestMaxPPrim2Orig(TestAddPPrim2Orig):
    """Supplies ``max_p`` data to the inherited prim2orig test case."""

    def init_data(self):
        """Populate the attributes the base test case reads."""
        self.op_type = 'max_p'

        # Operands of the primitive op and its output variable.
        lhs = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        rhs = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
        self.input = {'X': lhs, 'Y': rhs}
        out_var = self.layer_help.create_variable_for_type_inference(
            dtype=lhs.dtype)
        self.output = {'Z': out_var}
        self.attrs = {}

        # Arguments for the lowering rule and the op types expected afterwards.
        self.prim2orig_args = (lhs, rhs)
        self.all_ops = ['max_p', 'elementwise_max']
        # Maps the prim op's output variable to index 0.
        self.out_map = {self.output['Z']: 0}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -151,7 +151,7 @@ class TestWithoutProgramGuard(unittest.TestCase): ...@@ -151,7 +151,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'), (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'), ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
)) ))
# paddle.where, paddle.pow has no double grad definition, # paddle.where, paddle.pow, paddle.maximum has no double grad definition,
# can not compute forward grad use double trick # can not compute forward grad use double trick
class TestForwardGrad(unittest.TestCase): class TestForwardGrad(unittest.TestCase):
...@@ -273,6 +273,11 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) ...@@ -273,6 +273,11 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
# pow_p and pow has diff when compute z_dot of 0^0 # pow_p and pow has diff when compute z_dot of 0^0
('pow', paddle.pow, ('pow', paddle.pow,
(np.array([1, 2, 3]), np.array([0, 2, 7])), None, 'float32'), (np.array([1, 2, 3]), np.array([0, 2, 7])), None, 'float32'),
# To make max_p consistent with paddle.maximum, be sure x.grad = 0 and y.grad = 1 when x==y.
('max', paddle.maximum, (
np.array([1, 2, 3]),
np.array([2, 2, 2]),
), None, 'float32'),
)) ))
class TestGrad(unittest.TestCase): class TestGrad(unittest.TestCase):
......
...@@ -99,6 +99,7 @@ paddle.enable_static() ...@@ -99,6 +99,7 @@ paddle.enable_static()
(randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), (randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), ('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), ('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
)) ))
class TestPrimops(unittest.TestCase): class TestPrimops(unittest.TestCase):
......
...@@ -350,3 +350,8 @@ def eq(x, y, out=None): ...@@ -350,3 +350,8 @@ def eq(x, y, out=None):
@REGISTER_FN('pow_p', 'X', 'Y', 'Z') @REGISTER_FN('pow_p', 'X', 'Y', 'Z')
def pow(x, y, out=None): def pow(x, y, out=None):
return _simple_binop(LayerHelper('pow_p', **locals())) return _simple_binop(LayerHelper('pow_p', **locals()))
@REGISTER_FN('max_p', 'X', 'Y', 'Z')
def max(x, y, out=None):
    """Build the ``max_p`` primitive from ``x`` and ``y``.

    The call's own arguments (``x``, ``y``, ``out``) are forwarded as keyword
    arguments to ``LayerHelper``; the helper is then passed to
    ``_simple_binop``, whose result is returned.
    """
    # locals() is evaluated before `helper` is bound, so it holds exactly
    # x, y and out.
    helper = LayerHelper('max_p', **locals())
    return _simple_binop(helper)
...@@ -19,7 +19,7 @@ from . import primops ...@@ -19,7 +19,7 @@ from . import primops
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather, from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
matmul, mul, neg, reduce, reshape, scatter_add, set_value, matmul, mul, neg, reduce, reshape, scatter_add, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh, sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose, log, select, eq) transpose, log, select, eq, max)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp, REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose, lookup_orig2prim, lookup_prim2orig, lookup_transpose,
...@@ -317,6 +317,14 @@ def elementwise_pow_orig2prim(op, x, y): ...@@ -317,6 +317,14 @@ def elementwise_pow_orig2prim(op, x, y):
return z return z
@REGISTER_ORIG2PRIM('elementwise_max')
def elementwise_max_orig2prim(op, x, y):
    """Lower ``elementwise_max`` to the ``max`` primitive.

    When the operand shapes differ, ``y`` is first broadcast to the shape of
    ``x``; ``op`` itself is not read.
    """
    rhs = broadcast(y, shape=x.shape) if x.shape != y.shape else y
    return primops.max(x, rhs)
## Register prim2orig lower rules ## Register prim2orig lower rules
...@@ -466,6 +474,11 @@ def pow_prim2orig(op, x, y): ...@@ -466,6 +474,11 @@ def pow_prim2orig(op, x, y):
return paddle.pow(x, y) return paddle.pow(x, y)
@REGISTER_PRIM2ORIG('max_p')
def max_prim2orig(op, x, y):
    """Lower ``max_p`` back to ``paddle.maximum``; ``op`` is not read."""
    lowered = paddle.maximum(x, y)
    return lowered
## Register linearize rules ## Register linearize rules
@REGISTER_JVP('add_p') @REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot): def add_jvp(op, x_dot, y_dot):
...@@ -737,6 +750,26 @@ def pow_jvp(op, x_dot, y_dot): ...@@ -737,6 +750,26 @@ def pow_jvp(op, x_dot, y_dot):
return z_dot return z_dot
@REGISTER_JVP('max_p')
def max_jvp(op, x_dot, y_dot):
    """Linearize rule for ``max_p``.

    Returns ``None`` when both tangents are ``None``. Otherwise a missing
    tangent is replaced by a zero tensor shaped like the op output, and the
    result selects between the two tangents using ``eq(y, z)`` as the
    condition, with the ``y`` tangent as the first branch.
    """
    if x_dot is None and y_dot is None:
        return None

    _, y = op_position_inputs(op)
    z = op_position_output(op)

    # The zero tensor is always built, even when both tangents are present.
    zeros = fill_const(value=0.0, shape=z.shape, dtype=z.dtype)

    # To stay consistent with paddle.maximum when x == y, the condition is
    # taken from y alone (y equal to the output), so ties follow y's tangent
    # rather than being balanced between x and y as in Jax.
    y_is_max = eq(y, z)

    x_tangent = zeros if x_dot is None else x_dot
    y_tangent = zeros if y_dot is None else y_dot
    return select(y_is_max, y_tangent, x_tangent)
## Register transpose rules ## Register transpose rules
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册