Unverified commit b681c88c, authored by Sing_chan, committed by GitHub

[autograd] Add select_p, eq_p, and pow_p primitive operators for the new autograd (#44813)

* add select_p

* fix bugs

* add custom test for select_p; modify select_p primrules

* modify according to xiaoxu's comment

* add eq_p, select_p, pow_p; use autograd to test high-order grad

* add autograd as a test requirement; modify expected type of eq

* modify according to xiaoxu's comment

* import primops to use primops.pow
Parent 307801d5
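The end-to-end flow exercised by this PR can be summarized with a minimal usage sketch (not part of the diff, loosely modeled on the tests added below). It assumes a CPU build and relies on the rules registered here: paddle.pow is lowered to pow_p by the new orig2prim rule, and the pow_p JVP rule in turn uses eq_p and select_p to guard the 0^0 case.

import numpy as np
import paddle

paddle.enable_static()
paddle.incubate.autograd.enable_prim()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[3], dtype='float32')
    y = paddle.static.data(name='y', shape=[3], dtype='float32')
    x.stop_gradient = False
    y.stop_gradient = False
    z = paddle.pow(x, y)  # lowered to pow_p by the orig2prim rule
    x_grad, = paddle.incubate.autograd.grad([z], [x])
    # lower the remaining primitive ops back to original ops before execution
    paddle.incubate.autograd.prim2orig()

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
outs = exe.run(main,
               feed={'x': np.array([1., 2., 3.], dtype='float32'),
                     'y': np.array([0., 2., 7.], dtype='float32')},
               fetch_list=[x_grad])
paddle.incubate.autograd.disable_prim()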
@@ -24,7 +24,10 @@ set(PRIM_OP_SRCS
tanh_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc
log_p_op.cc)
log_p_op.cc
select_p_op.cc
eq_p_op.cc
pow_p_op.cc)
cc_test(
prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class EqPrimOp : public framework::OperatorBase {
public:
EqPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator eq_p should not be excuted directly"));
}
};
class EqPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of eq_p op.");
AddInput("Y", "(Tensor), The input tensor of eq_p op.");
AddOutput("Z", "(Tensor), The output tensor of eq_p op.");
AddComment(R"DOC(
Autograd primitive eq_p operator.
)DOC");
}
};
class EqPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class EqPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(eq_p,
paddle::operators::EqPrimOp,
paddle::operators::EqPrimOpMaker,
paddle::operators::EqPrimOpShapeInference,
paddle::operators::EqPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class PowPrimOp : public framework::OperatorBase {
public:
PowPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator pow_p should not be excuted directly"));
}
};
class PowPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The base of pow_p op.");
AddInput("Y", "(Tensor), The exponents of pow_p op.");
AddOutput("Z", "(Tensor), The output tensor of pow_p op.");
AddComment(R"DOC(
Autograd primitive pow_p operator.
)DOC");
}
};
class PowPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class PowPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(pow_p,
paddle::operators::PowPrimOp,
paddle::operators::PowPrimOpMaker,
paddle::operators::PowPrimOpShapeInference,
paddle::operators::PowPrimOpVarTypeInference);
@@ -35,6 +35,9 @@ USE_OP_ITSELF(tanh_p);
USE_OP_ITSELF(matmul_p);
USE_OP_ITSELF(fill_constant_p);
USE_OP_ITSELF(log_p);
USE_OP_ITSELF(select_p);
USE_OP_ITSELF(eq_p);
USE_OP_ITSELF(pow_p);
namespace paddle {
namespace framework {
@@ -615,5 +618,74 @@ TEST(PrimOp, log_p) {
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, select_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{2, 3};
std::string cond = "cond";
std::string x = "x";
std::string y = "y";
std::string z = "z";
NewVar(block, cond, shape);
NewVar(block, x, shape);
NewVar(block, y, shape);
AppendOp(block,
"select_p",
{{"Condition", {cond}}, {"X", {x}}, {"Y", {y}}},
{{"Z", {z}}},
{});
ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("z")->GetShape();
ASSERT_EQ(shapes.size(), 2UL);
ASSERT_EQ(shapes[0], 2L);
ASSERT_EQ(shapes[1], 3L);
}
TEST(PrimOp, eq_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x = "x";
std::string y = "y";
std::string z = "z";
NewVar(block, x, shape);
NewVar(block, y, shape);
AppendOp(block, "eq_p", {{"X", {x}}, {"Y", {y}}}, {{"Z", {z}}}, {});
ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType::BOOL);
auto shapes = block->Var("z")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, pow_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x = "x";
std::string y = "y";
std::string z = "z";
NewVar(block, x, shape);
NewVar(block, y, shape);
AppendOp(block, "pow_p", {{"X", {x}}, {"Y", {y}}}, {{"Z", {z}}}, {});
ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("z")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class SelectPrimOp : public framework::OperatorBase {
public:
SelectPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator select_p should not be excuted directly"));
}
};
class SelectPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Condition", "(Tensor), The input condition of select_p op.");
AddInput("X", "(Tensor), The input tensor of select_p op.");
AddInput("Y", "(Tensor), The input tensor of select_p op.");
AddOutput("Z", "(Tensor), The output tensor of select_p op.");
AddComment(R"DOC(
Autograd primitive select_p operator.
)DOC");
}
};
class SelectPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr condition_var_ptr =
ctx->GetInputVarPtrs("Condition")[0];
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *condition_var =
PADDLE_GET(framework::VarDesc *, condition_var_ptr);
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto condition_shape = condition_var->GetShape();
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t condition_rank = condition_shape.size();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(
condition_rank,
x_rank,
platform::errors::InvalidArgument(
"The dimensions of condtion and Inputs(X) should be same, "
"but get %d and %d",
condition_rank,
x_rank));
PADDLE_ENFORCE_EQ(
x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of Inputs(X) and Inputs(Y) should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < condition_rank; ++i) {
PADDLE_ENFORCE_EQ(condition_shape[i],
x_shape[i],
platform::errors::InvalidArgument(
"The shape of condition and Inputs(X) at dimension "
"%d should be same, "
"but get %d and %d",
i,
condition_shape[i],
x_shape[i]));
}
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of Inputs(X) and Inputs(Y) at dimension "
"%d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(condition_shape);
}
};
class SelectPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(select_p,
paddle::operators::SelectPrimOp,
paddle::operators::SelectPrimOpMaker,
paddle::operators::SelectPrimOpShapeInference,
paddle::operators::SelectPrimOpVarTypeInference);
@@ -20,4 +20,5 @@ set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160)
set_tests_properties(test_minimize PROPERTIES TIMEOUT 60)
if(NOT WIN32)
set_tests_properties(test_autograd_functional_prim PROPERTIES TIMEOUT 60)
set_tests_properties(test_primapi PROPERTIES TIMEOUT 60)
endif()
@@ -867,5 +867,120 @@ class TestScatterAddPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestSelectPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'select_p'
Cond = paddle.static.data(name='Condition', shape=[9, 5], dtype='bool')
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[9, 5], dtype='float64')
self.prim_input = {'Condition': Cond, 'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
Cond_DOT = paddle.static.data(name='Cond_DOT',
shape=[9, 5],
dtype='float64')
X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[9, 5], dtype='float64')
self.jvp_args = (Cond_DOT, X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose
check_dot = lambda v: True
Z_BAR = paddle.static.data(name='Z_BAR', shape=[9, 5], dtype='float64')
self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, 1: Y}
self.all_ops = [
# prim op:
'select_p',
# jvp op:
'select_p',
# transpose op:
'fill_constant_p',
'fill_constant_p',
'fill_constant_p',
'select_p',
'select_p',
]
class TestEqPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'eq_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'eq_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'pow_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='float32')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='float32')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='float32')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'pow_p',
# jvp op:
'fill_constant_p',
'fill_constant_p',
'eq_p',
'select_p',
'sub_p',
'mul_p',
'mul_p',
'pow_p',
'mul_p',
'mul_p',
'log_p',
'add_p'
# transpose op:
]
if __name__ == '__main__':
unittest.main()
@@ -481,5 +481,63 @@ class TestAssignOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestWhereOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'where'
Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool')
X = paddle.static.data(name='X', shape=[5, 6], dtype='float32')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32')
self.input = {'Condition': Cond, 'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (Cond, X, Y)
self.all_ops = ['where', 'select_p']
self.out_map = {0: self.output['Out']}
class TestEqualOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'equal'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['equal', 'eq_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestPowOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_pow'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_pow', 'pow_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
@@ -497,5 +497,62 @@ class TestFillConstantPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Y']: 0}
class TestSelectPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'select_p'
Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool')
X = paddle.static.data(name='X', shape=[5, 6], dtype='float32')
Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32')
self.input = {'Condition': Cond, 'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (Cond, X, Y)
self.all_ops = ['select_p', 'where']
self.out_map = {self.output['Z']: 0}
class TestEqPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'eq_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['eq_p', 'equal']
self.out_map = {self.output['Z']: 0}
class TestPowPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'pow_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['pow_p', 'elementwise_pow']
self.out_map = {self.output['Z']: 0}
if __name__ == '__main__':
unittest.main()
@@ -16,6 +16,9 @@ import typing
import unittest
import numpy as np
import autograd
import autograd.numpy as np_autograd
import paddle
import config
@@ -148,6 +151,8 @@ class TestWithoutProgramGuard(unittest.TestCase):
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
))
# paddle.where and paddle.pow have no double-grad definition,
# so their forward gradients cannot be computed with the double-grad trick.
class TestForwardGrad(unittest.TestCase):
@classmethod
@@ -239,24 +244,36 @@ class TestForwardGrad(unittest.TestCase):
paddle.incubate.autograd.disable_prim()
where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), (
('matmul', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('add', paddle.add,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('input_not_sequence', paddle.tanh,
(np.random.rand(5, 5), ), None, 'float64'),
('input_gradients_not_none', paddle.matmul,
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), ), 'float64'),
('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
))
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
(
('matmul', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('add', paddle.add,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('input_not_sequence', paddle.tanh,
(np.random.rand(5, 5), ), None, 'float64'),
('input_gradients_not_none', paddle.matmul,
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), ), 'float64'),
('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
# For the where op, the grad of the condition computed by paddle.static.gradients is None,
# and paddle.incubate.autograd.grad replaces that None with zeros, while the transpose rule
# simply returns None because cond_dot is unused; this is a known difference.
('select', where_wrap,
(np.random.rand(3, 4), np.random.rand(3, 4)), None, 'float32'),
# pow_p and paddle.pow differ when computing z_dot at 0^0
('pow', paddle.pow,
(np.array([1, 2, 3]), np.array([0, 2, 7])), None, 'float32'),
))
class TestGrad(unittest.TestCase):
def setUp(self):
@@ -367,6 +384,33 @@ class TestGrad(unittest.TestCase):
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
def multiply_pd(x):
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x2)
x4 = paddle.multiply(x3, x)
return x4
multiply_ag = lambda xs: xs[0] * xs[0] * xs[0] * xs[0] * xs[0]
sin_ag = lambda xs: np_autograd.sin(xs[0])
cos_ag = lambda xs: np_autograd.cos(xs[0])
exp_ag = lambda xs: np_autograd.exp(xs[0])
pow_ag = lambda xs: xs[0]**xs[1]
log_ag = lambda xs: np_autograd.log(xs[0])
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), (
('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
))
class TestGradWithHigherOrder(unittest.TestCase):
def setUp(self):
@@ -377,105 +421,58 @@ class TestGradWithHigherOrder(unittest.TestCase):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
def test_third_order(self):
paddle.incubate.autograd.enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
grad1, = paddle.incubate.autograd.grad([x4], [x])
grad2, = paddle.incubate.autograd.grad([grad1], [x])
grad3, = paddle.incubate.autograd.grad([grad2], [x])
paddle.incubate.autograd.prim2orig(main.block(0))
feed = {x.name: np.array([2.]).astype('float32')}
fetch_list = [grad3.name]
result = [np.array([48.])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim()
@classmethod
def setUpClass(cls):
cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
cls._rtol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("rtol")
cls._atol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("atol")
def test_fourth_order(self):
paddle.incubate.autograd.enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
x5 = paddle.multiply(x4, x)
out = paddle.sqrt(x5 + x4)
grad1, = paddle.incubate.autograd.grad([out], [x])
grad2, = paddle.incubate.autograd.grad([grad1], [x])
grad3, = paddle.incubate.autograd.grad([grad2], [x])
grad4, = paddle.incubate.autograd.grad([grad3], [x])
paddle.incubate.autograd.prim2orig(main.block(0))
feed = {
x.name: np.array([2.]).astype('float32'),
}
fetch_list = [grad4.name]
# (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5)
result = [np.array([-0.27263762711])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim()
def test_grad(self):
def test_fifth_order(self):
paddle.incubate.autograd.enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
x5 = paddle.multiply(x4, x)
x6 = paddle.multiply(x5, x)
out = x6 + x5
grad1, = paddle.incubate.autograd.grad([out], [x])
grad2, = paddle.incubate.autograd.grad([grad1], [x])
grad3, = paddle.incubate.autograd.grad([grad2], [x])
grad4, = paddle.incubate.autograd.grad([grad3], [x])
grad5, = paddle.incubate.autograd.grad([grad4], [x])
paddle.incubate.autograd.prim2orig()
feed = {
x.name: np.array([2.]).astype('float32'),
}
fetch_list = [grad5.name]
result = [np.array([1560.0])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim()
def expected():
egrad = autograd.elementwise_grad
grad_3 = egrad(egrad(egrad(self.fun_ag)))(self.xs)
grad_4 = egrad(egrad(egrad(egrad(self.fun_ag))))(self.xs)
grad_5 = egrad(egrad(egrad(egrad(egrad(self.fun_ag)))))(self.xs)
# the output of egrad is a tuple
return list(grad_3 + grad_4 + grad_5)
def actual():
paddle_grad = paddle.incubate.autograd.grad
paddle.incubate.autograd.enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun_pd(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun_pd(static_xs)
grad1 = paddle_grad(ys, static_xs, static_v)
grad2 = paddle_grad(grad1, static_xs, static_v)
grad3 = paddle_grad(grad2, static_xs, static_v)
grad4 = paddle_grad(grad3, static_xs, static_v)
grad5 = paddle_grad(grad4, static_xs, static_v)
paddle.incubate.autograd.prim2orig()
fetch_list = [grad3, grad4, grad5]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
paddle.incubate.autograd.disable_prim()
return outs
actual = actual()
expected = expected()
self.assertEqual(type(actual), type(expected))
for i, j in zip(actual, expected):
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
if __name__ == '__main__':
......
@@ -95,6 +95,10 @@ paddle.enable_static()
'dtype': paddle.float32
}, (3, 2), 'float32'),
('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'),
('select', primops.select,
(randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
))
class TestPrimops(unittest.TestCase):
......
@@ -137,11 +137,6 @@ def exp(x, out=None):
return _simple_unop(LayerHelper('exp_p', **locals()))
@REGISTER_FN('log_p', 'X', 'Y')
def log(x, out=None):
return _simple_unop(LayerHelper('log_p', **locals()))
@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
return _manipulation_unop(LayerHelper('reshape_p', **locals()))
@@ -315,3 +310,43 @@ def scatter_add(x, y, indextensor, axis, out=None):
outputs={'Z': out},
attrs=attrs)
return out
@REGISTER_FN('log_p', 'X', 'Y')
def log(x, out=None):
return _simple_unop(LayerHelper('log_p', **locals()))
@REGISTER_FN('select_p', 'Condition', 'X', 'Y', 'Z')
def select(cond, x, y, out=None):
if len(cond.shape) != len(x.shape):
raise ValueError(
"len(cond.shape) should be equal to len(x.shape), but len(cond.shape)={} and len(x.shape)={}."
.format(len(cond.shape), len(x.shape)))
if len(x.shape) != len(y.shape):
raise ValueError(
"len(x.shape) should be equal to len(y.shape), but len(x.shape)={} and len(y.shape)={}."
.format(len(x.shape), len(y.shape)))
helper = LayerHelper('select_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=helper.layer_type,
inputs={
'Condition': cond,
'X': x,
'Y': y
},
outputs={'Z': out})
return out
@REGISTER_FN('eq_p', 'X', 'Y', 'Z')
def eq(x, y, out=None):
return _simple_binop(LayerHelper('eq_p', **locals()))
@REGISTER_FN('pow_p', 'X', 'Y', 'Z')
def pow(x, y, out=None):
return _simple_binop(LayerHelper('pow_p', **locals()))
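For reference, a small sketch (not part of the diff) of calling these new wrappers directly inside a static program, mirroring the primops unit test above; the import path paddle.incubate.autograd.primops is an assumption based on where these helpers appear to live.

import paddle
from paddle.incubate.autograd import primops  # assumed import path

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    cond = paddle.static.data(name='cond', shape=[2, 3], dtype='bool')
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    y = paddle.static.data(name='y', shape=[2, 3], dtype='float32')
    z_sel = primops.select(cond, x, y)  # appends a select_p op; dtype follows x
    z_eq = primops.eq(x, y)             # appends an eq_p op; output dtype is bool
    z_pow = primops.pow(x, y)           # appends a pow_p op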
@@ -15,10 +15,11 @@ import typing
import paddle
from . import primops
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
matmul, mul, neg, reduce, reshape, scatter_add, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose, log)
transpose, log, select, eq)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
@@ -66,6 +67,10 @@ index_select
scale
assign
sqrt
log
select
equal
elementwise_pow
These original ops are partially supported:
@@ -290,6 +295,28 @@ def p_norm_orig2prim(op, x):
raise RuntimeError('Only support lower l2/l1 norm currently')
# TODO: support broadcast
@REGISTER_ORIG2PRIM('where')
def select_orig2prim(op, condition, x, y):
return select(condition, x, y)
@REGISTER_ORIG2PRIM('equal')
def equal_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return eq(x, y)
@REGISTER_ORIG2PRIM('elementwise_pow')
def elementwise_pow_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
z = primops.pow(x, y)
return z
## Register prim2orig lower rules
@@ -424,6 +451,21 @@ def fill_constant_prim2orig(op):
dtype=INT_DTYPE_2_STRING[op.attr('dtype')])
@REGISTER_PRIM2ORIG('select_p')
def select_prim2orig(op, condition, x, y):
return paddle.where(condition, x, y)
@REGISTER_PRIM2ORIG('eq_p')
def eq_prim2orig(op, x, y):
return paddle.equal(x, y)
@REGISTER_PRIM2ORIG('pow_p')
def pow_prim2orig(op, x, y):
return paddle.pow(x, y)
## Register linearize rules
@REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot):
@@ -646,6 +688,55 @@ def scatter_add_jvp(op, x_dot, y_dot):
return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis)
@REGISTER_JVP('select_p')
def select_jvp(op, cond_dot, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
cond, x, y = op_position_inputs(op)
if x_dot is None:
x_dot = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
if y_dot is None:
y_dot = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
return select(cond, x_dot, y_dot)
@REGISTER_JVP('eq_p')
def eq_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('pow_p')
def pow_jvp(op, x_dot, y_dot):
def _compute_t1(x, y):
zero_y = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
one_y = fill_const(value=1.0, shape=y.shape, dtype=y.dtype)
cond = eq(y, zero_y)
new_y = select(cond, one_y, sub(y, one_y))
t1 = mul(x_dot, mul(y, primops.pow(x, new_y)))
return t1
if x_dot is None and y_dot is None:
return None
x, y = op_position_inputs(op)
z = op_position_output(op)
if y_dot is None:
return _compute_t1(x, y)
elif x_dot is None:
return mul(y_dot, mul(log(x), z))
else:
t1, t2 = _compute_t1(x, y), mul(y_dot, mul(log(x), z))
z_dot = add(t1, t2)
return z_dot
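For reference, the rule above implements the total derivative of z = x**y; as a sketch (not taken from the diff), in Python notation:

z_dot = x_dot * y * x**(y - 1) + y_dot * log(x) * x**y

The eq_p/select_p guard in _compute_t1 replaces the exponent y - 1 with 1 wherever y == 0, so the first term evaluates to x_dot * (0 * x) = 0 instead of touching x**(-1); when y_dot is None only the first term is produced, and when x_dot is None only the second.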
## Register transpose rules
@@ -831,3 +922,22 @@ def scatter_add_transpose(op, check_dot, z_bar):
y_bar = gather(z_bar, indextensor, axis=axis)
indextensor_bar = None
return x_bar, y_bar, indextensor_bar
@REGISTER_TRANSPOSE('select_p')
def select_transpose(op, check_dot, z_bar):
cond, x, y = op_position_inputs(op)
assert check_dot(cond) or check_dot(x) or check_dot(y), (
f'At least one of check_dot(cond), check_dot(x) and check_dot(y) must be True, '
f'but check_dot(cond)={check_dot(cond)}, check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.'
)
zeros_x = fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
zeros_y = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
cond_bar = fill_const(value=0.0, shape=y.shape,
dtype=cond.dtype) if check_dot(cond) else None
x_bar = select(cond, z_bar, zeros_x) if check_dot(x) else None
y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None
return cond_bar, x_bar, y_bar
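A quick NumPy sanity check of the transpose rule above (illustrative only, not part of the diff): for z = where(cond, x, y), the cotangent z_bar routes to x where cond holds and to y elsewhere, so every element of z_bar lands in exactly one of x_bar and y_bar.

import numpy as np

cond = np.random.rand(3, 4) > 0.5
z_bar = np.random.rand(3, 4)
x_bar = np.where(cond, z_bar, 0.0)  # matches select(cond, z_bar, zeros_x)
y_bar = np.where(cond, 0.0, z_bar)  # matches select(cond, zeros_y, z_bar)
assert np.allclose(x_bar + y_bar, z_bar)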
@@ -14,3 +14,4 @@ scipy>=1.5; python_version == "3.6"
prettytable
distro
numpy>=1.20,<1.22; python_version >= "3.7"
autograd==1.4