Unverified commit 4db8e5c7 authored by J Jiabin Yang, committed by GitHub

【Prim】Add gather vjp (#50305)

* tmp gather vjp

* support gather

* remove useless code

* fix compiling error

* fix ut

* add eager test

* add eager test

* add seed

* fix cpu error

* fix transpose op compat

* remove tensor index case

* fix prim_cinn

* fix ut
Parent 613a3ffe
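For context before the diffs: the vjp (vector-Jacobian product) added here is gather's gradient expressed through primitive ops. A rough NumPy sketch of the gradient rule being implemented, assuming a 1-D integer index along `axis` (illustrative only, not Paddle API code):

```python
import numpy as np

def gather_vjp_reference(x, index, out_grad, axis=0):
    # Reference VJP of out = np.take(x, index, axis=axis): scatter-add out_grad
    # back into a zero tensor shaped like x (duplicate indices must accumulate).
    grad_x = np.zeros_like(x)
    grad_x_view = np.moveaxis(grad_x, axis, 0)    # gathered axis to the front
    out_grad_view = np.moveaxis(out_grad, axis, 0)
    np.add.at(grad_x_view, index, out_grad_view)  # writes through to grad_x
    return grad_x
```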
...@@ -26,7 +26,8 @@ paddle/phi/api/lib/tensor_operants.cc
paddle/phi/extension.h
paddle/phi/include/*
paddle/phi/infermeta/generated.*
paddle/fluid/prim/api/generated_prim/*.cc
paddle/fluid/prim/api/generated_prim/*.h
*.DS_Store
*.vs
build/
......
...@@ -61,10 +61,10 @@ class ElementwiseAddCompositeGradOpMaker
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y");
auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(6) << "Runing add_grad composite func"; VLOG(6) << "Runing add_grad composite func";
......
...@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
...@@ -132,6 +134,33 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -132,6 +134,33 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
protected:
void Apply() override {
paddle::experimental::Tensor index = this->GetSingleForwardInput("Index");
paddle::optional<paddle::experimental::Tensor> tensor_axis =
this->GetOptionalSingleForwardInput("Axis");
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor dout = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(*dx_ptr);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(3) << "Runing gather_grad composite func";
if (tensor_axis.is_initialized()) {
PADDLE_THROW(platform::errors::Unimplemented(
"We don't support dynamic index from tensor for gather composite "
"grad for now. "));
} else {
prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr);
}
this->RecoverOutputName(dx, dx_name);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace operators
...@@ -146,6 +175,7 @@ REGISTER_OPERATOR(gather,
ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>,
ops::GatherCompositeGradOpMaker,
GatherInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(gather_grad,
GatherGradInferShapeFunctor,
......
...@@ -23,4 +23,5 @@
- scatter
- scatter_nd_add
- tile
- transpose
- subtract
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# prim api list
white_ops_list = [
"pow",
"scale",
"multiply",
"unsqueeze",
"expand",
"full",
"reshape",
"divide",
"sum",
"exp",
"scatter",
"transpose",
]
inplace_out_type_map = {
"Tensor": "Tensor&",
"std::vector<Tensor>": "std::vector<Tensor>&",
}
inplace_optional_out_type_map = {
"Tensor": "paddle::optional<Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&",
}
class BaseAPI:
def __init__(self, api_item_yaml, prims=tuple()):
# self.api = api_item_yaml['op']
self.api = api_item_yaml['name']
self.is_prim_api = False
if api_item_yaml['name'] in prims:
self.is_prim_api = True
#######################################
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
# outputs:
# names : [], list of output names
# types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor>
########################################
if self.is_prim_api:
(
self.inputs,
self.attrs,
self.outputs,
self.optional_vars,
) = self.parse_args(self.api, api_item_yaml)
self.inplace_map = api_item_yaml['inplace']
def get_api_func_name(self):
return self.api
# def is_inplace(self):
# if self.inplace_map
# return True
# return False
def get_input_tensor_args(self, inplace_flag=False):
input_args = []
inplace_type_map = {
"const Tensor&": "Tensor&",
"const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
"const std::vector<Tensor>&": "std::vector<Tensor>&",
"const paddle::optional<std::vector<Tensor>>&": "paddle::optional<std::vector<Tensor>>&",
}
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(
inplace_type_map[self.inputs['input_info'][name]]
+ ' '
+ name
)
else:
input_args.append(self.inputs['input_info'][name] + ' ' + name)
return input_args
def get_declare_args(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(
self.attrs['attr_info'][name][0] + ' ' + name + default_value
)
return ", ".join(declare_args)
def get_declare_args_nodefault(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name)
return ", ".join(declare_args)
def get_return_type(self, inplace_flag=False):
out_type_list = []
for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
if self.inplace_map[out_name] in self.optional_vars:
out_type_list.append(
inplace_optional_out_type_map[out_type]
)
else:
out_type_list.append(inplace_out_type_map[out_type])
else:
out_type_list.append(out_type)
if len(out_type_list) == 1:
return out_type_list[0]
else:
return "std::tuple<" + ", ".join(out_type_list) + ">"
def parse_args(self, api_name, api_item_yaml):
optional_vars = []
for input_dict in api_item_yaml['inputs']:
if input_dict['optional']:
optional_vars.append(input_dict['name'])
inputs, attrs = self.parse_input_and_attr(
api_item_yaml['inputs'], api_item_yaml['attrs']
)
output_type_list, output_names, out_size_expr = self.parse_output(
api_item_yaml['outputs']
)
return (
inputs,
attrs,
{
'names': output_names,
'types': output_type_list,
'out_size_expr': out_size_expr,
},
optional_vars,
)
def parse_input_and_attr(self, inputs_list, attrs_list):
input_types_map = {
'Tensor': 'const Tensor&',
'Tensor[]': 'const std::vector<Tensor>&',
}
attr_types_map = {
'IntArray': 'const IntArray&',
'Scalar': 'const Scalar&',
'Scalar(int)': 'const Scalar&',
'Scalar(int64_t)': 'const Scalar&',
'Scalar(float)': 'const Scalar&',
'Scalar(dobule)': 'const Scalar&',
'Scalar[]': 'const std::vector<phi::Scalar>&',
'int': 'int',
'int32_t': 'int32_t',
'int64_t': 'int64_t',
'long': 'long',
'size_t': 'size_t',
'float': 'float',
'float[]': 'const std::vector<float>&',
'double': 'double',
'bool': 'bool',
'bool[]': 'const std::vector<bool>&',
'str': 'const std::string&',
'str[]': 'const std::vector<std::string>&',
'Place': 'const Place&',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&',
}
optional_types_trans = {
'Tensor': 'const paddle::optional<Tensor>&',
'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
'int': 'paddle::optional<int>',
'int32_t': 'paddle::optional<int32_t>',
'int64_t': 'paddle::optional<int64_t>',
'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>',
'Place': 'paddle::optional<const Place&>',
'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>',
}
inputs = {'names': [], 'input_info': {}}
for input_dict in inputs_list:
inputs['names'].append(input_dict['name'])
if input_dict['optional']:
inputs['input_info'][input_dict['name']] = optional_types_trans[
input_dict['typename']
]
else:
inputs['input_info'][input_dict['name']] = input_types_map[
input_dict['typename']
]
attrs = {'names': [], 'attr_info': {}}
for attr_dict in attrs_list:
attrs['names'].append(attr_dict['name'])
if 'default_value' in attr_dict.keys():
default_value = attr_dict['default_value']
else:
default_value = None
if 'optional' in attr_dict.keys():
attrs['attr_info'][attr_dict['name']] = (
optional_types_trans[attr_dict['typename']],
default_value,
)
else:
attrs['attr_info'][attr_dict['name']] = (
attr_types_map[attr_dict['typename']],
default_value,
)
return inputs, attrs
def parse_output(self, outputs_list):
out_type_list = []
out_name_list = []
out_size_expr_list = []
for output_dict in outputs_list:
if output_dict['intermediate']:
continue
out_type_list.append(output_dict['typename'])
out_name_list.append(output_dict['name'])
if 'size' in output_dict.keys():
out_size_expr_list.append(output_dict['size'])
else:
out_size_expr_list.append(None)
return out_type_list, out_name_list, out_size_expr_list
class EagerPrimAPI(BaseAPI):
def __init__(self, api_item_yaml, prims=tuple()):
super().__init__(api_item_yaml, prims)
def get_api__func_name(self):
api_func_name = self.api
# if self.is_inplace:
# if api_func_name[-1] != '_':
# api_func_name += '_'
# print("after api name", api_func_name)
return api_func_name
def gene_prim_api_declaration(self):
api_declaration = ""
api_func_name = self.get_api__func_name()
if api_func_name[-1] != '_':
api_declaration = f"""
template <typename T>
{self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""
else:
api_declaration = (
api_declaration
+ f"""
template <typename T>
{self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
)
return api_declaration
def get_ad_func_input_args(self, inplace_flag=False):
input_args = []
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(name)
else:
input_args.append(name)
return input_args
def get_ad_func_args(self, inplace_flag=False):
ad_func_args = self.get_ad_func_input_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
ad_func_args.append(name)
ad_func_args_str = ", ".join(ad_func_args)
return ad_func_args_str
def gene_ad_func_call(self):
api_func_name = self.get_api__func_name()
dygraph_ad_func_name = '::' + api_func_name + '_ad_func'
dygraph_ad_func_parameters = self.get_ad_func_args()
ad_func_call_str = f"""
VLOG(4) << "Eager Prim API {api_func_name}_ad_func call";
return {dygraph_ad_func_name}({dygraph_ad_func_parameters});
"""
# print("ad_func_call_str: ", ad_func_call_str)
return ad_func_call_str
def gene_eager_prim_api_code(self):
api_code = ""
indent = " "
api_func_name = self.get_api__func_name()
template = '<Tensor>'
# func declaration
if api_func_name[-1] != '_':
api_code = f"""
template <>
{self.get_return_type()} {api_func_name}{template}({self.get_declare_args_nodefault()})
"""
else:
api_code = f"""
template <>
{self.get_return_type(inplace_flag=True)} {api_func_name}{template}({self.get_declare_args_nodefault(inplace_flag=True)})
"""
# func code
api_code = api_code + '{'
api_code += f"""{self.gene_ad_func_call()}"""
api_code += '}' + '\n'
return api_code
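To make the generator above concrete, here is a hypothetical YAML entry in the shape `BaseAPI.parse_args` expects, together with (in comments) what the helpers return for it and roughly what `gene_eager_prim_api_code()` would emit. The op name, signature, and `gather_ad_func` call are illustrative assumptions, not copied from the real ops.yaml:

```python
api_item_yaml = {
    'name': 'gather',
    'inputs': [
        {'name': 'x', 'typename': 'Tensor', 'optional': False},
        {'name': 'index', 'typename': 'Tensor', 'optional': False},
    ],
    'attrs': [
        {'name': 'axis', 'typename': 'Scalar(int)', 'default_value': '0'},
    ],
    'outputs': [
        {'name': 'out', 'typename': 'Tensor', 'intermediate': False},
    ],
    'inplace': {},
}

api = EagerPrimAPI(api_item_yaml, prims=('gather',))
# api.get_return_type()   -> 'Tensor'
# api.get_declare_args()  -> 'const Tensor& x, const Tensor& index, const Scalar& axis = 0'
# api.gene_eager_prim_api_code() produces roughly:
#
#   template <>
#   Tensor gather<Tensor>(const Tensor& x, const Tensor& index, const Scalar& axis)
#   {
#       VLOG(4) << "Eager Prim API gather_ad_func call";
#       return ::gather_ad_func(x, index, axis);
#   }
```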
...@@ -24,6 +24,40 @@ using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
// This function should have the same signature as phi, which is defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void gather_grad(const Tensor& x,
const Tensor& index,
const Tensor& out_grad,
const Scalar& axis,
bool overwrite,
Tensor* grad_x) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int> tmp_perm;
// move the gathered axis to position 0
int axis_value = axis.to<int>();
tmp_perm.push_back(axis_value);
// append the remaining axes
for (int i = 0; i < x.dims().size(); ++i) {
if (i != axis_value) {
tmp_perm.push_back(i);
}
}
std::vector<int> reverse_perm(tmp_perm);
// build the inverse permutation that restores the original axis order
for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
reverse_perm[tmp_perm[i]] = i;
}
// transpose out_grad and the zero grad to the target layout.
auto tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
auto tmp_out_grad = transpose<T>(out_grad, tmp_perm);
// scatter grad to grad_x
auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
auto tmp_grad_x_transposed = transpose<T>(tmp_grad_x, reverse_perm);
set_output<T>(tmp_grad_x_transposed, grad_x);
}
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
......
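The `gather_grad` composite above moves the gathered axis to the front, scatter-adds the incoming gradient into a zero tensor, and transposes back. A rough eager-mode Python translation of that decomposition, as a sketch only (it assumes `paddle.scatter` with `overwrite=False` accumulates updates at duplicate indices, which is what the composite relies on):

```python
import paddle

def composite_gather_grad(x, index, out_grad, axis):
    # Mirror of prim::gather_grad: move `axis` to position 0, scatter-add the
    # incoming gradient into zeros, then restore the original axis order.
    rank = len(x.shape)
    perm = [axis] + [i for i in range(rank) if i != axis]
    reverse_perm = [0] * rank
    for i, p in enumerate(perm):
        reverse_perm[p] = i
    zeros = paddle.zeros(x.shape, dtype=x.dtype)
    tmp = paddle.scatter(paddle.transpose(zeros, perm),
                         index,
                         paddle.transpose(out_grad, perm),
                         overwrite=False)
    return paddle.transpose(tmp, reverse_perm)
```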
...@@ -38,9 +38,9 @@ namespace prim {
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op_. After it is called (through operator()), the pairs of
(gradient variable, corresponding input variable of fwd_op_) will be added to
grad_to_var. If an input variable of fwd_op_ is contained in no_grad_set, its
gradient variable will be ignored or kEmptyVarName depending on the template
argument DropEmptyIG in the derived classes.
*/
...@@ -114,34 +114,40 @@ class CompositeGradOpMakerBase {
paddle::optional<paddle::experimental::Tensor> GetOptionalSingleForwardOutput(
const std::string& name) {
paddle::optional<paddle::experimental::Tensor> output_opt;
if (fwd_op_.Outputs().find(name) != fwd_op_.Outputs().end()) {
framework::VarDesc* output_desc = this->SingleForwardOutput(name);
if (!output_desc) return output_opt;
paddle::experimental::Tensor output = paddle::experimental::Tensor(
std::make_shared<DescTensor>(output_desc));
output_opt = paddle::make_optional<paddle::experimental::Tensor>(output);
}
return output_opt;
}
paddle::optional<paddle::experimental::Tensor> GetOptionalSingleForwardInput(
const std::string& name) {
paddle::optional<paddle::experimental::Tensor> input_opt;
if (fwd_op_.Inputs().find(name) != fwd_op_.Inputs().end()) {
framework::VarDesc* input_desc = this->SingleForwardInput(name);
if (!input_desc) return input_opt;
paddle::experimental::Tensor input = paddle::experimental::Tensor(
std::make_shared<DescTensor>(input_desc));
input_opt = paddle::make_optional<paddle::experimental::Tensor>(input);
}
return input_opt;
}
paddle::optional<paddle::experimental::Tensor> GetOptionalSingleOutputGrad(
const std::string& name) {
paddle::optional<paddle::experimental::Tensor> output_grad_opt;
if (fwd_op_.Outputs().find(name) != fwd_op_.Outputs().end()) {
framework::VarDesc* output_grad_desc = this->SingleOutputGrad(name);
if (!output_grad_desc) return output_grad_opt;
paddle::experimental::Tensor output_grad = paddle::experimental::Tensor(
std::make_shared<DescTensor>(output_grad_desc));
output_grad_opt =
paddle::make_optional<paddle::experimental::Tensor>(output_grad);
}
return output_grad_opt;
}
...@@ -457,16 +463,44 @@ class CompositeGradOpMakerBase {
framework::VarDesc* SingleForwardInput(const std::string& name) const {
// Copy Var from original block to active block, or create a new one.
auto fwd_in_names = fwd_op_.Input(name);
if (!fwd_in_names.empty()) {
PADDLE_ENFORCE_EQ(
fwd_in_names.size(),
1,
phi::errors::InvalidArgument(
"When calling SingleForward for op: %s's Input: %s, we should "
"only get one input tensor, but we got %d instead.",
fwd_op_.Type(),
name,
fwd_in_names.size()));
CopyVarFromOrig(fwd_op_.Input(name).at(0));
return StaticCompositeContext::Instance().GetBlock()->FindVar(
fwd_op_.Input(name).at(0));
} else {
return nullptr;
}
}
framework::VarDesc* SingleForwardOutput(const std::string& name) const {
// Copy Var from original block to active block, or create a new one.
auto fwd_out_names = fwd_op_.Output(name);
if (!fwd_out_names.empty()) {
PADDLE_ENFORCE_EQ(
fwd_out_names.size(),
1,
phi::errors::InvalidArgument(
"When calling SingleForward for op: %s's Output: %s, we should "
"only get one input tensor, but we got %d instead.",
fwd_op_.Type(),
name,
fwd_out_names.size()));
CopyVarFromOrig(fwd_op_.Output(name).at(0));
return StaticCompositeContext::Instance().GetBlock()->FindVar(
fwd_op_.Output(name).at(0));
} else {
return nullptr;
}
}
std::vector<framework::VarDesc*> MultiForwardInput(
......
...@@ -1675,7 +1675,10 @@
- op : transpose (transpose2)
backward : transpose_grad (transpose2_grad)
attrs:
perm : axis
extra :
outputs : [XShape]
attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", bool use_quantizer = false,
str mkldnn_data_type = "float32"]
......
...@@ -129,8 +129,17 @@ class TestPrimForwardAndBackward(unittest.TestCase):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
all_ops = [
op.type
for op in net.forward.program_cache.last()[-1][-1]
.train_program.block(0)
.ops
]
# Ensure that softmax is split into small ops
self.assertTrue('softmax' not in fwd_ops)
for op in all_ops:
if op != "matmul_v2_grad":
self.assertTrue("_grad" not in op)
def test_cinn_prim(self):
dy_res = self.train(use_prim=False)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal0', 'index', 'axis', 'x_dtype', 'index_dtype', 'v'),
[
(
np.random.rand(100),
np.array([1, 3, 5]),
0,
np.float32,
np.int32,
np.random.rand(3),
),
(
np.random.rand(10, 20),
np.array([1, 3, 5]),
0,
np.float64,
np.int64,
np.random.rand(3, 20),
),
(
np.random.rand(10, 20),
np.array([1, 1, 3]),
0,
np.float32,
np.int32,
np.random.rand(3, 20),
),
(
np.random.rand(3, 88, 30),
np.array([1, 3, 5]),
1,
np.float32,
np.int32,
np.random.rand(3, 3, 30),
),
(
np.random.rand(10, 88, 10),
np.array([1, 3, 5]),
0,
np.float32,
np.int32,
np.random.rand(3, 88, 10),
),
],
)
class TestGatherGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.x_dtype)
cls.index = cls.index.astype(cls.index_dtype)
cls.v = cls.v.astype(cls.x_dtype)
@classmethod
def tearDownClass(cls):
core._set_prim_backward_enabled(False)
def test_gather_grad_comp(self):
def actual(primal0, index, axis):
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(
primal0, dtype=primal0.dtype, stop_gradient=False
)
index = paddle.to_tensor(index, dtype=index.dtype)
x.stop_gradient = False
index.stop_gradient = True
out = paddle.gather(x, index, axis)
res = paddle.grad(out, [x], create_graph=False, retain_graph=True)
return res[0].numpy()
def desired(primal0, index, axis):
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(
primal0, dtype=primal0.dtype, stop_gradient=False
)
index = paddle.to_tensor(index, dtype=index.dtype)
x.stop_gradient = False
index.stop_gradient = True
out = paddle.gather(x, index, axis)
res = paddle.grad(out, [x], create_graph=False, retain_graph=True)
return res[0].numpy()
np.testing.assert_allclose(
actual=actual(self.primal0, self.index, self.axis),
desired=desired(self.primal0, self.index, self.axis),
rtol=1e-6,
atol=0,
)
if __name__ == '__main__':
unittest.main()
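Condensed from the test above, toggling the composite backward in eager mode and checking that both paths agree can be done directly (a minimal sketch using the same APIs as the test):

```python
import numpy as np
import paddle
from paddle.fluid import core

def gather_dx(prim_enabled):
    # Take gather's gradient with the composite (prim) backward on or off.
    core._set_prim_backward_enabled(prim_enabled)
    x = paddle.to_tensor(
        np.arange(12, dtype='float32').reshape(4, 3), stop_gradient=False
    )
    index = paddle.to_tensor(np.array([1, 1, 3], dtype='int32'))
    out = paddle.gather(x, index, axis=0)
    (dx,) = paddle.grad(out, [x])
    core._set_prim_backward_enabled(False)
    return dx.numpy()

np.testing.assert_allclose(gather_dx(True), gather_dx(False), rtol=1e-6)
```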
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core, framework
np.random.seed(2023)
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4)
def forward(self, x, index, axis):
tmp = self.fc(x)
out = paddle.gather(tmp, index, axis)
return out
@param.parameterized_class(
('primal0', 'index', 'axis', 'x_dtype', 'index_dtype', 'v', "count"),
[
(
np.random.rand(100),
np.array([1, 3, 5]),
0,
np.float32,
np.int32,
np.random.rand(3),
0,
),
(
np.random.rand(10, 20),
np.array([1, 3, 5]),
0,
np.float64,
np.int64,
np.random.rand(3, 20),
1,
),
(
np.random.rand(10, 20),
np.array([1, 1, 3]),
0,
np.float32,
np.int32,
np.random.rand(3, 20),
2,
),
(
# Something is wrong with the gather grad cpu kernel
np.random.rand(3, 88, 30),
np.array([1, 3, 5]),
1,
np.float32,
np.int32,
np.random.rand(3, 3, 30),
3,
),
(
np.random.rand(10, 88, 10),
np.array([1, 3, 5]),
0,
np.float16,
np.int32,
np.random.rand(3, 88, 10),
4,
),
],
)
class TestGatherGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.x_dtype)
cls.index = cls.index.astype(cls.index_dtype)
cls.v = cls.v.astype(cls.x_dtype)
def train(self, use_prim, use_cinn):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.index = paddle.to_tensor(np.array([0, 1]))
self.x.stop_gradient = False
net = PrimeNet()
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x, self.index, 0)
res = paddle.autograd.grad(out, [self.x])
return res
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
# TODO(jiabin): CINN will crash in this case; enable it when fixed
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-6,
atol=1e-6,
)
paddle.enable_static()
def test_gather_grad_comp(self):
paddle.enable_static()
def actual(primal0, index, axis, v):
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
index_tmp = paddle.static.data(
'index', index.shape, index.dtype
)
x.stop_gradient = False
index_tmp.stop_gradient = True
z = paddle.gather(x, index_tmp, axis)
z_grad = paddle.static.data('v', z.shape, z.dtype)
res = paddle.static.gradients([z], [x], [z_grad])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'index': index,
'v': v,
},
fetch_list=[res[0].name],
)
return out[0]
def desired(primal0, index, axis, v):
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
index_tmp = paddle.static.data(
'index', index.shape, index.dtype
)
x.stop_gradient = False
index_tmp.stop_gradient = True
z = paddle.gather(x, index_tmp, axis)
z_grad = paddle.static.data('v', z.shape, z.dtype)
res = paddle.static.gradients([z], [x], [z_grad])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'index': index,
'v': v,
},
fetch_list=[res[0].name],
)
return out[0]
dx = None
ddx = None
# fp16 is not supported for cpu gather
if not (
(self.count == 4)
and isinstance(
framework._current_expected_place(), framework.core.CPUPlace
)
):
dx = actual(self.primal0, self.index, self.axis, self.v)
ddx = desired(self.primal0, self.index, self.axis, self.v)
if (self.count >= 3) and isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# Scatter in phi has a problem with its cpu kernel for these cases, so skip them
pass
elif (self.count == 4) and (
not isinstance(
framework._current_expected_place(), framework.core.CPUPlace
)
):
# FP16 test case
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-3,
atol=0,
)
elif self.count == 1:
# FP64 test case
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-15,
atol=1e-15,
)
else:
# FP32 test cases
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-5,
atol=0,
)
core._set_prim_backward_enabled(False)
paddle.disable_static()
if __name__ == '__main__':
unittest.main()