Unverified · Commit 463fc15e authored by Sing_chan, committed by GitHub

【autograd】add log_p primitive operator for new autograd (#44779)

* add log_p for auto_grad

* add log_p_op.cc in prim_op_test srcs

* fix bug of wrong op name; add test in test_primops

* add test case of log in test_primapi

* fix bug of test_without_guard

* no need to fix test_without_guard
Parent ad716551
@@ -23,7 +23,8 @@ set(PRIM_OP_SRCS
  sqrt_p_op.cc
  tanh_p_op.cc
  matmul_p_op.cc
  fill_constant_p_op.cc
  log_p_op.cc)
cc_test(
  prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {

class LogPrimOp : public framework::OperatorBase {
 public:
  LogPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator log_p should not be executed directly"));
  }
};
class LogPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of log_p op.");
    AddOutput("Y", "(Tensor), The output tensor of log_p op.");
    AddComment(R"DOC(
Autograd primitive log_p operator.
)DOC");
  }
};

class LogPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
  }
};

class LogPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x_name = Input(ctx, "X")[0];
    auto y_name = Output(ctx, "Y")[0];
    SetType(ctx, y_name, GetType(ctx, x_name));
    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(log_p,
                  paddle::operators::LogPrimOp,
                  paddle::operators::LogPrimOpMaker,
                  paddle::operators::LogPrimOpShapeInference,
                  paddle::operators::LogPrimOpVarTypeInference);
@@ -34,6 +34,7 @@ USE_OP_ITSELF(sqrt_p);
USE_OP_ITSELF(tanh_p);
USE_OP_ITSELF(matmul_p);
USE_OP_ITSELF(fill_constant_p);
USE_OP_ITSELF(log_p);

namespace paddle {
namespace framework {

@@ -595,5 +596,24 @@ TEST(PrimOp, fill_constant_p) {
  ASSERT_EQ(shapes[2], 5L);
}

TEST(PrimOp, log_p) {
  ProgramDesc program;
  auto *block = program.MutableBlock(0);
  std::vector<int64_t> shape{3, 4, 5};

  std::string x0 = "x0";
  std::string x1 = "x1";

  NewVar(block, x0, shape);
  AppendOp(block, "log_p", {{"X", {x0}}}, {{"Y", {x1}}}, {});
  ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
  ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
  auto shapes = block->Var("x1")->GetShape();
  ASSERT_EQ(shapes.size(), 3UL);
  ASSERT_EQ(shapes[0], 3L);
  ASSERT_EQ(shapes[1], 4L);
  ASSERT_EQ(shapes[2], 5L);
}
}  // namespace framework
}  // namespace paddle
@@ -364,6 +364,35 @@ class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):
        ]

class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
        # Set prim op
        self.op_type = 'log_p'
        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {
            'X': X,
        }
        self.prim_output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.prim_attrs = {}

        # Set JVP
        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (X_DOT, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        self.all_ops = [
            # prim op:
            'log_p',
            # jvp op:
            'div_p',
            # transpose op:
        ]

class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
......
@@ -208,6 +208,26 @@ class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
        self.out_map = {0: self.output['Out']}

class TestLogOrig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
        self.op_type = 'log'
        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')

        self.input = {
            'X': X,
        }
        self.output = {
            'Out':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.orig2prim_args = (X, )
        self.all_ops = ['log', 'log_p']
        self.out_map = {0: self.output['Out']}

class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
......
@@ -224,6 +224,26 @@ class TestExpPPrim2Orig(TestAddPPrim2Orig):
        self.out_map = {self.output['Y']: 0}

class TestLogPPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
        self.op_type = 'log_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, )
        self.all_ops = ['log_p', 'log']
        self.out_map = {self.output['Y']: 0}

class TestReshapePPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
......
@@ -146,6 +146,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
    ('input_gradients_not_none', paddle.matmul,
     (np.random.rand(3, 3), np.random.rand(3, 3)),
     (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
    ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
))
class TestForwardGrad(unittest.TestCase):

@@ -254,6 +255,7 @@ class TestForwardGrad(unittest.TestCase):
    ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
    ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
    ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
    ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):
......
@@ -41,6 +41,7 @@ paddle.enable_static()
    ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
    ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
    ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
    ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
    ('reshape', primops.reshape, randn(2, 3), {
        'shape': (3, 2)
    }, (3, 2), 'float64'),
......
@@ -137,6 +137,11 @@ def exp(x, out=None):
    return _simple_unop(LayerHelper('exp_p', **locals()))

@REGISTER_FN('log_p', 'X', 'Y')
def log(x, out=None):
    return _simple_unop(LayerHelper('log_p', **locals()))

@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
    return _manipulation_unop(LayerHelper('reshape_p', **locals()))
......
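The `log` primitive added above follows the same `_simple_unop` pattern as `exp`: calling it inside a static-graph program appends a single `log_p` op whose output inherits the input's shape and dtype (per the shape and var-type inference in log_p_op.cc). A minimal usage sketch, assuming the module path `paddle.incubate.autograd.primops` implied by this PR's test imports:

import paddle
from paddle.incubate.autograd import primops  # path assumed from this PR's test layout

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float64')
    y = primops.log(x)  # appends one log_p op to the current block
    print(y.shape)      # expected [2, 3]: Y mirrors X's shape and dtype

Note that log_p, like every *_p operator, cannot be executed directly (its RunImpl throws); a program containing it must first be lowered back to original ops via the prim2orig rule registered in the next file.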
@@ -18,7 +18,7 @@ import paddle
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
                      matmul, mul, neg, reduce, reshape, scatter_add, set_value,
                      sin, slice_assign, slice_select, split, sqrt, sub, tanh,
                      transpose, log)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
                      REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
                      lookup_orig2prim, lookup_prim2orig, lookup_transpose,

@@ -166,6 +166,11 @@ def exp_orig2prim(op, x):
    return exp(x)

@REGISTER_ORIG2PRIM('log')
def log_orig2prim(op, x):
    return log(x)

@REGISTER_ORIG2PRIM('fill_zeros_like')
def fill_zeros_like_orig2prim(op, x):
    return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)

@@ -333,6 +338,11 @@ def exp_prim2orig(op, x):
    return paddle.exp(x)

@REGISTER_PRIM2ORIG('log_p')
def log_prim2orig(op, x):
    return paddle.log(x)

@REGISTER_PRIM2ORIG('reshape_p')
def reshape_prim2orig(op, x):
    return paddle.reshape(x, shape=op.attr('shape'))

@@ -509,6 +519,14 @@ def exp_jvp(op, x_dot):
    return mul(x_dot, y)

@REGISTER_JVP('log_p')
def log_jvp(op, x_dot):
    if x_dot is None:
        return None
    x, = op_position_inputs(op)
    return div(x_dot, x)

@REGISTER_JVP('reshape_p')
def reshape_jvp(op, x_dot):
    if x_dot is None:
......
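The `log_jvp` rule above is just the chain rule for the natural log: with y = log(x), the forward-mode tangent is y_dot = x_dot / x, which is why it lowers to a single `div_p` op (the op TestLogPJVPAndTranspose expects). A quick finite-difference sanity check of that rule in plain NumPy, independent of the prim framework (variable names here are illustrative only):

import numpy as np

x = np.random.rand(3, 4) + 0.5       # keep inputs safely away from zero
x_dot = np.random.rand(3, 4)         # tangent (direction of differentiation)
eps = 1e-6
fd_jvp = (np.log(x + eps * x_dot) - np.log(x)) / eps  # finite-difference JVP
assert np.allclose(fd_jvp, x_dot / x, atol=1e-4)      # matches div(x_dot, x)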