未验证 提交 4db8e5c7 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Add gather vjp (#50305)

* tmp gather vjp

* support gather

* remove useless code

* fix compiling error

* fix ut

* add eager test

* add eager test

* add seed

* fix cpu error

* fix transpose op compat

* remove tensor index case

* fix prim_cinn

* fix ut
上级 613a3ffe
......@@ -26,7 +26,8 @@ paddle/phi/api/lib/tensor_operants.cc
paddle/phi/extension.h
paddle/phi/include/*
paddle/phi/infermeta/generated.*
paddle/fluid/prim/api/generated_prim/*.cc
paddle/fluid/prim/api/generated_prim/*.h
*.DS_Store
*.vs
build/
......
......@@ -61,10 +61,10 @@ class ElementwiseAddCompositeGradOpMaker
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&dx);
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y");
auto dy_ptr = this->GetOutputPtr(&dy);
auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(6) << "Runing add_grad composite func";
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
......@@ -132,6 +134,33 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
protected:
void Apply() override {
paddle::experimental::Tensor index = this->GetSingleForwardInput("Index");
paddle::optional<paddle::experimental::Tensor> tensor_axis =
this->GetOptionalSingleForwardInput("Axis");
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor dout = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(*dx_ptr);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(3) << "Runing gather_grad composite func";
if (tensor_axis.is_initialized()) {
PADDLE_THROW(platform::errors::Unimplemented(
"We don't support dynamic index from tensor for gather composite "
"grad for now. "));
} else {
prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr);
}
this->RecoverOutputName(dx, dx_name);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace operators
......@@ -146,6 +175,7 @@ REGISTER_OPERATOR(gather,
ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>,
ops::GatherCompositeGradOpMaker,
GatherInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(gather_grad,
GatherGradInferShapeFunctor,
......
......@@ -23,4 +23,5 @@
- scatter
- scatter_nd_add
- tile
- transpose
- subtract
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# prim api list
white_ops_list = [
"pow",
"scale",
"multiply",
"unsqueeze",
"expand",
"full",
"reshape",
"divide",
"sum",
"exp",
"scatter",
"transpose",
]
inplace_out_type_map = {
"Tensor": "Tensor&",
"std::vector<Tensor>": "std::vector<Tensor>&",
}
inplace_optional_out_type_map = {
"Tensor": "paddle::optional<Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&",
}
class BaseAPI:
def __init__(self, api_item_yaml, prims=tuple()):
# self.api = api_item_yaml['op']
self.api = api_item_yaml['name']
self.is_prim_api = False
if api_item_yaml['name'] in prims:
self.is_prim_api = True
#######################################
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
# outputs:
# names : [], list of output names
# types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor>
########################################
if self.is_prim_api:
(
self.inputs,
self.attrs,
self.outputs,
self.optional_vars,
) = self.parse_args(self.api, api_item_yaml)
self.inplace_map = api_item_yaml['inplace']
def get_api_func_name(self):
return self.api
# def is_inplace(self):
# if self.inplace_map
# return True
# return False
def get_input_tensor_args(self, inplace_flag=False):
input_args = []
inplace_type_map = {
"const Tensor&": "Tensor&",
"const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
"const std::vector<Tensor>&": "std::vector<Tensor>&",
"const paddle::optional<std::vector<Tensor>>&": "paddle::optional<std::vector<Tensor>>&",
}
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(
inplace_type_map[self.inputs['input_info'][name]]
+ ' '
+ name
)
else:
input_args.append(self.inputs['input_info'][name] + ' ' + name)
return input_args
def get_declare_args(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(
self.attrs['attr_info'][name][0] + ' ' + name + default_value
)
return ", ".join(declare_args)
def get_declare_args_nodefault(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name)
return ", ".join(declare_args)
def get_return_type(self, inplace_flag=False):
out_type_list = []
for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
if self.inplace_map[out_name] in self.optional_vars:
out_type_list.append(
inplace_optional_out_type_map[out_type]
)
else:
out_type_list.append(inplace_out_type_map[out_type])
else:
out_type_list.append(out_type)
if len(out_type_list) == 1:
return out_type_list[0]
else:
return "std::tuple<" + ", ".join(out_type_list) + ">"
def parse_args(self, api_name, api_item_yaml):
optional_vars = []
for input_dict in api_item_yaml['inputs']:
if input_dict['optional']:
optional_vars.append(input_dict['name'])
inputs, attrs = self.parse_input_and_attr(
api_item_yaml['inputs'], api_item_yaml['attrs']
)
output_type_list, output_names, out_size_expr = self.parse_output(
api_item_yaml['outputs']
)
return (
inputs,
attrs,
{
'names': output_names,
'types': output_type_list,
'out_size_expr': out_size_expr,
},
optional_vars,
)
def parse_input_and_attr(self, inputs_list, attrs_list):
input_types_map = {
'Tensor': 'const Tensor&',
'Tensor[]': 'const std::vector<Tensor>&',
}
attr_types_map = {
'IntArray': 'const IntArray&',
'Scalar': 'const Scalar&',
'Scalar(int)': 'const Scalar&',
'Scalar(int64_t)': 'const Scalar&',
'Scalar(float)': 'const Scalar&',
'Scalar(dobule)': 'const Scalar&',
'Scalar[]': 'const std::vector<phi::Scalar>&',
'int': 'int',
'int32_t': 'int32_t',
'int64_t': 'int64_t',
'long': 'long',
'size_t': 'size_t',
'float': 'float',
'float[]': 'const std::vector<float>&',
'double': 'double',
'bool': 'bool',
'bool[]': 'const std::vector<bool>&',
'str': 'const std::string&',
'str[]': 'const std::vector<std::string>&',
'Place': 'const Place&',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&',
}
optional_types_trans = {
'Tensor': 'const paddle::optional<Tensor>&',
'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
'int': 'paddle::optional<int>',
'int32_t': 'paddle::optional<int32_t>',
'int64_t': 'paddle::optional<int64_t>',
'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>',
'Place': 'paddle::optional<const Place&>',
'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>',
}
inputs = {'names': [], 'input_info': {}}
for input_dict in inputs_list:
inputs['names'].append(input_dict['name'])
if input_dict['optional']:
inputs['input_info'][input_dict['name']] = optional_types_trans[
input_dict['typename']
]
else:
inputs['input_info'][input_dict['name']] = input_types_map[
input_dict['typename']
]
attrs = {'names': [], 'attr_info': {}}
for attr_dict in attrs_list:
attrs['names'].append(attr_dict['name'])
if 'default_value' in attr_dict.keys():
default_value = attr_dict['default_value']
else:
default_value = None
if 'optional' in attr_dict.keys():
attrs['attr_info'][attr_dict['name']] = (
optional_types_trans[attr_dict['typename']],
default_value,
)
else:
attrs['attr_info'][attr_dict['name']] = (
attr_types_map[attr_dict['typename']],
default_value,
)
return inputs, attrs
def parse_output(self, outputs_list):
out_type_list = []
out_name_list = []
out_size_expr_list = []
for output_dict in outputs_list:
if output_dict['intermediate']:
continue
out_type_list.append(output_dict['typename'])
out_name_list.append(output_dict['name'])
if 'size' in output_dict.keys():
out_size_expr_list.append(output_dict['size'])
else:
out_size_expr_list.append(None)
return out_type_list, out_name_list, out_size_expr_list
class EagerPrimAPI(BaseAPI):
def __init__(self, api_item_yaml, prims=tuple()):
super().__init__(api_item_yaml, prims)
def get_api__func_name(self):
api_func_name = self.api
# if self.is_inplace:
# if api_func_name[-1] != '_':
# api_func_name += '_'
# print("after api name", api_func_name)
return api_func_name
def gene_prim_api_declaration(self):
api_declaration = ""
api_func_name = self.get_api__func_name()
if api_func_name[-1] != '_':
api_declaration = f"""
template <typename T>
{self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""
else:
api_declaration = (
api_declaration
+ f"""
template <typename T>
{self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
)
return api_declaration
def get_ad_func_input_args(self, inplace_flag=False):
input_args = []
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(name)
else:
input_args.append(name)
return input_args
def get_ad_func_args(self, inplace_flag=False):
ad_func_args = self.get_ad_func_input_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
ad_func_args.append(name)
ad_func_args_str = ", ".join(ad_func_args)
return ad_func_args_str
def gene_ad_func_call(self):
api_func_name = self.get_api__func_name()
dygraph_ad_func_name = '::' + api_func_name + '_ad_func'
dygraph_ad_func_parameters = self.get_ad_func_args()
ad_func_call_str = f"""
VLOG(4) << "Eager Prim API {api_func_name}_ad_func call";
return {dygraph_ad_func_name}({dygraph_ad_func_parameters});
"""
# print("ad_func_call_str: ", ad_func_call_str)
return ad_func_call_str
def gene_eager_prim_api_code(self):
api_code = ""
indent = " "
api_func_name = self.get_api__func_name()
template = '<Tensor>'
# func decalaration
if api_func_name[-1] != '_':
api_code = f"""
template <>
{self.get_return_type()} {api_func_name}{template}({self.get_declare_args_nodefault()})
"""
else:
api_code = f"""
template <>
{self.get_return_type(inplace_flag=True)} {api_func_name}{template}({self.get_declare_args_nodefault(inplace_flag=True)})
"""
# func code
api_code = api_code + '{'
api_code += f"""{self.gene_ad_func_call()}"""
api_code += '}' + '\n'
return api_code
......@@ -24,6 +24,40 @@ using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void gather_grad(const Tensor& x,
const Tensor& index,
const Tensor& out_grad,
const Scalar& axis,
bool overwrite,
Tensor* grad_x) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int> tmp_perm;
// change axis to rank 0
int axis_value = axis.to<int>();
tmp_perm.push_back(axis_value);
// make other ranks
for (int i = 0; i < x.dims().size(); ++i) {
if (i != axis_value) {
tmp_perm.push_back(i);
}
}
std::vector<int> reverse_perm(tmp_perm);
// make origin ranks
for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
reverse_perm[tmp_perm[i]] = i;
}
// transpose out_grad and zero grad to target rank.
auto tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
auto tmp_out_grad = transpose<T>(out_grad, tmp_perm);
// scatter grad to grad_x
auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
auto tmp_grad_x_tranposed = transpose<T>(tmp_grad_x, reverse_perm);
set_output<T>(tmp_grad_x_tranposed, grad_x);
}
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
......
......@@ -38,9 +38,9 @@ namespace prim {
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op. After it is called (through operator()), the pairs of
(gradient variable, corresponding input variable of fwd_op) will be added to
grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its
operator fwd_op_. After it is called (through operator()), the pairs of
(gradient variable, corresponding input variable of fwd_op_) will be added to
grad_to_var. If an input variable of fwd_op_ is contained in no_grad_set, its
gradient variable will be ignored or kEmptyVarName depending on the template
argument DropEmptyIG in the derived classes.
*/
......@@ -114,34 +114,40 @@ class CompositeGradOpMakerBase {
paddle::optional<paddle::experimental::Tensor> GetOptionalSingleForwardOutput(
const std::string& name) {
paddle::optional<paddle::experimental::Tensor> output_opt;
if (fwd_op_.Outputs().find(name) != fwd_op_.Outputs().end()) {
framework::VarDesc* output_desc = this->SingleForwardOutput(name);
if (!output_desc) return output_opt;
paddle::experimental::Tensor output =
paddle::experimental::Tensor(std::make_shared<DescTensor>(output_desc));
paddle::experimental::Tensor output = paddle::experimental::Tensor(
std::make_shared<DescTensor>(output_desc));
output_opt = paddle::make_optional<paddle::experimental::Tensor>(output);
}
return output_opt;
}
paddle::optional<paddle::experimental::Tensor> GetOptionalSingleForwardInput(
const std::string& name) {
paddle::optional<paddle::experimental::Tensor> input_opt;
if (fwd_op_.Inputs().find(name) != fwd_op_.Inputs().end()) {
framework::VarDesc* input_desc = this->SingleForwardInput(name);
if (!input_desc) return input_opt;
paddle::experimental::Tensor input =
paddle::experimental::Tensor(std::make_shared<DescTensor>(input_desc));
paddle::experimental::Tensor input = paddle::experimental::Tensor(
std::make_shared<DescTensor>(input_desc));
input_opt = paddle::make_optional<paddle::experimental::Tensor>(input);
}
return input_opt;
}
paddle::optional<paddle::experimental::Tensor> GetOptionalSingleOutputGrad(
const std::string& name) {
paddle::optional<paddle::experimental::Tensor> output_grad_opt;
if (fwd_op_.Outputs().find(name) != fwd_op_.Outputs().end()) {
framework::VarDesc* output_grad_desc = this->SingleOutputGrad(name);
if (!output_grad_desc) return output_grad_opt;
paddle::experimental::Tensor output_grad = paddle::experimental::Tensor(
std::make_shared<DescTensor>(output_grad_desc));
output_grad_opt =
paddle::make_optional<paddle::experimental::Tensor>(output_grad);
}
return output_grad_opt;
}
......@@ -457,16 +463,44 @@ class CompositeGradOpMakerBase {
framework::VarDesc* SingleForwardInput(const std::string& name) const {
// Copy Var from original block to active block, or create a new one.
auto fwd_in_names = fwd_op_.Input(name);
if (!fwd_in_names.empty()) {
PADDLE_ENFORCE_EQ(
fwd_in_names.size(),
1,
phi::errors::InvalidArgument(
"When calling SingleForward for op: %s's Input: %s, we should "
"only get one input tensor, but we got %d instead.",
fwd_op_.Type(),
name,
fwd_in_names.size()));
CopyVarFromOrig(fwd_op_.Input(name).at(0));
return StaticCompositeContext::Instance().GetBlock()->FindVar(
fwd_op_.Input(name).at(0));
} else {
return nullptr;
}
}
framework::VarDesc* SingleForwardOutput(const std::string& name) const {
// Copy Var from original block to active block, or create a new one.
auto fwd_out_names = fwd_op_.Output(name);
if (!fwd_out_names.empty()) {
PADDLE_ENFORCE_EQ(
fwd_out_names.size(),
1,
phi::errors::InvalidArgument(
"When calling SingleForward for op: %s's Output: %s, we should "
"only get one input tensor, but we got %d instead.",
fwd_op_.Type(),
name,
fwd_out_names.size()));
CopyVarFromOrig(fwd_op_.Output(name).at(0));
return StaticCompositeContext::Instance().GetBlock()->FindVar(
fwd_op_.Output(name).at(0));
} else {
return nullptr;
}
}
std::vector<framework::VarDesc*> MultiForwardInput(
......
......@@ -1675,7 +1675,10 @@
- op : transpose (transpose2)
backward : transpose_grad (transpose2_grad)
attrs:
perm : axis
extra :
outputs : [XShape]
attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", bool use_quantizer = false,
str mkldnn_data_type = "float32"]
......
......@@ -129,8 +129,17 @@ class TestPrimForwardAndBackward(unittest.TestCase):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
all_ops = [
op.type
for op in net.forward.program_cache.last()[-1][-1]
.train_program.block(0)
.ops
]
# Ensure that softmax is splitted into small ops
self.assertTrue('softmax' not in fwd_ops)
for op in all_ops:
if op != "matmul_v2_grad":
self.assertTrue("_grad" not in op)
def test_cinn_prim(self):
dy_res = self.train(use_prim=False)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal0', 'index', 'axis', 'x_dtype', 'index_dtype', 'v'),
[
(
np.random.rand(100),
np.array([1, 3, 5]),
0,
np.float32,
np.int32,
np.random.rand(3),
),
(
np.random.rand(10, 20),
np.array([1, 3, 5]),
0,
np.float64,
np.int64,
np.random.rand(3, 20),
),
(
np.random.rand(10, 20),
np.array([1, 1, 3]),
0,
np.float32,
np.int32,
np.random.rand(3, 20),
),
(
np.random.rand(3, 88, 30),
np.array([1, 3, 5]),
1,
np.float32,
np.int32,
np.random.rand(3, 3, 30),
),
(
np.random.rand(10, 88, 10),
np.array([1, 3, 5]),
0,
np.float32,
np.int32,
np.random.rand(3, 88, 10),
),
],
)
class TestGatherGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.x_dtype)
cls.index = cls.index.astype(cls.index_dtype)
cls.v = cls.v.astype(cls.x_dtype)
@classmethod
def tearDownClass(cls):
core._set_prim_backward_enabled(False)
def test_exp_grad_comp(self):
def actual(primal0, index, axis):
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(
primal0, dtype=primal0.dtype, stop_gradient=False
)
index = paddle.to_tensor(index, dtype=index.dtype)
x.stop_gradient = False
index.stop_gradient = True
out = paddle.gather(x, index, axis)
res = paddle.grad(out, [x], create_graph=False, retain_graph=True)
return res[0].numpy()
def desired(primal0, index, axis):
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(
primal0, dtype=primal0.dtype, stop_gradient=False
)
index = paddle.to_tensor(index, dtype=index.dtype)
x.stop_gradient = False
index.stop_gradient = True
out = paddle.gather(x, index, axis)
res = paddle.grad(out, [x], create_graph=False, retain_graph=True)
return res[0].numpy()
np.testing.assert_allclose(
actual=actual(self.primal0, self.index, self.axis),
desired=desired(self.primal0, self.index, self.axis),
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core, framework
np.random.seed(2023)
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4)
def forward(self, x, index, axis):
tmp = self.fc(x)
out = paddle.gather(tmp, index, axis)
return out
@param.parameterized_class(
('primal0', 'index', 'axis', 'x_dtype', 'index_dtype', 'v', "count"),
[
(
np.random.rand(100),
np.array([1, 3, 5]),
0,
np.float32,
np.int32,
np.random.rand(3),
0,
),
(
np.random.rand(10, 20),
np.array([1, 3, 5]),
0,
np.float64,
np.int64,
np.random.rand(3, 20),
1,
),
(
np.random.rand(10, 20),
np.array([1, 1, 3]),
0,
np.float32,
np.int32,
np.random.rand(3, 20),
2,
),
(
# Something wrong with gather grad cpu kernel
np.random.rand(3, 88, 30),
np.array([1, 3, 5]),
1,
np.float32,
np.int32,
np.random.rand(3, 3, 30),
3,
),
(
np.random.rand(10, 88, 10),
np.array([1, 3, 5]),
0,
np.float16,
np.int32,
np.random.rand(3, 88, 10),
4,
),
],
)
class TestGatherGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.x_dtype)
cls.index = cls.index.astype(cls.index_dtype)
cls.v = cls.v.astype(cls.x_dtype)
def train(self, use_prim, use_cinn):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.index = paddle.to_tensor(np.array([0, 1]))
self.x.stop_gradient = False
net = PrimeNet()
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x, self.index, 0)
res = paddle.autograd.grad(out, [self.x])
return res
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
# TODO(jiabin): CINN will crashed in this case open it when fixed
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-6,
atol=1e-6,
)
paddle.enable_static()
def test_tanh_grad_comp(self):
paddle.enable_static()
def actual(primal0, index, axis, v):
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
index_tmp = paddle.static.data(
'index', index.shape, index.dtype
)
x.stop_gradient = False
index_tmp.stop_gradient = True
z = paddle.gather(x, index_tmp, axis)
z_grad = paddle.static.data('v', z.shape, z.dtype)
res = paddle.static.gradients([z], [x], [z_grad])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'index': index,
'v': v,
},
fetch_list=[res[0].name],
)
return out[0]
def desired(primal0, index, axis, v):
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
index_tmp = paddle.static.data(
'index', index.shape, index.dtype
)
x.stop_gradient = False
index_tmp.stop_gradient = True
z = paddle.gather(x, index_tmp, axis)
z_grad = paddle.static.data('v', z.shape, z.dtype)
res = paddle.static.gradients([z], [x], [z_grad])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'index': index,
'v': v,
},
fetch_list=[res[0].name],
)
return out[0]
dx = None
ddx = None
# fp16 is not supported for cpu gather
if not (
(self.count == 4)
and isinstance(
framework._current_expected_place(), framework.core.CPUPlace
)
):
dx = actual(self.primal0, self.index, self.axis, self.v)
ddx = desired(self.primal0, self.index, self.axis, self.v)
if (self.count >= 3) and isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# Scatter in phi has problem with cpu kernel of case 4, so skip this
pass
elif (self.count == 4) and (
not isinstance(
framework._current_expected_place(), framework.core.CPUPlace
)
):
# FP16 test case
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-3,
atol=0,
)
elif self.count == 1:
# FP64 test case
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-15,
atol=1e-15,
)
else:
# FP32 test cases
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-5,
atol=0,
)
core._set_prim_backward_enabled(False)
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册