Unverified commit ab1b6303, authored by J Jiabin Yang, committed by GitHub

【Prim】Reshape, transpose, cast vjp (#50778)

* support transpose and reshape

* support reshape, transpose, cast vjp

* merge develop

* recover unused file

* remove prim base

* support problem

* remove additional status setting

* remove additional status setting

* fix ut

* fix ut

* fix ut

* fix no grad branch

* add more test

* disable fp16 in cpu

* fix test
Parent f8ce3a2c
......@@ -22,6 +22,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace operators {
......@@ -63,6 +66,24 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
class CastCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
paddle::experimental::Tensor out_grad = paddle::experimental::Tensor(
std::make_shared<prim::DescTensor>(this->SingleOutputGrad("Out")));
paddle::experimental::Tensor x_grad = paddle::experimental::Tensor(
std::make_shared<prim::DescTensor>(this->SingleInputGrad("X")));
auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad);
auto dtype = static_cast<paddle::experimental::DataType>(
this->Attr<int>("in_dtype"));
prim::cast_grad<prim::DescTensor>(out_grad, dtype, dx_ptr);
this->RecoverOutputName(x_grad, dx_name);
}
};
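The composite maker above lowers cast's backward into the prim cast_grad rule: the incoming output gradient is cast back to the input's original dtype, taken from the "in_dtype" attribute. A minimal numpy sketch of that rule with illustrative shapes (this is the math only, not the Paddle API):

import numpy as np

x = np.random.rand(10, 10).astype(np.float32)         # forward input of cast
grad_out = np.random.rand(10, 10).astype(np.float64)  # gradient of out = cast(x, float64)
grad_x = grad_out.astype(x.dtype)                      # cast_grad: cast the gradient back to x's dtype
assert grad_x.dtype == x.dtype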
class CastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -134,6 +155,7 @@ REGISTER_OPERATOR(cast,
ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastCompositeGradOpMaker,
ops::CastOpProtoMaker);
// [ why register transfer_dtype_op alias with cast_op? ]
......
......@@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/phi_utils.h"
// can only include the headers in paddle/phi/api dirs
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/infermeta_utils.h"
......@@ -26,7 +29,6 @@ limitations under the License. */
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/reshape_grad_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
namespace paddle {
namespace framework {
class InferShapeContext;
......@@ -571,6 +573,25 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> {
}
};
class Reshape2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
// We prefer to use x.shape instead of xshape; this differs from the
// PHI definition.
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto *dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
VLOG(6) << "Running reshape2_grad composite func";
prim::reshape_grad<prim::DescTensor>(x, out_grad, dx_ptr);
this->RecoverOutputName(dx, dx_name);
}
};
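As the comment above notes, the composite rule is built from x's own shape rather than the XShape output, so reshape's backward reduces to reshaping the output gradient back to the input shape. A minimal numpy sketch of that rule with illustrative shapes (not the Paddle API):

import numpy as np

x = np.random.rand(10, 1, 10)
out = x.reshape(10, 10)             # forward reshape
grad_out = np.random.rand(10, 10)   # gradient w.r.t. out
grad_x = grad_out.reshape(x.shape)  # reshape_grad: reshape the gradient back to x.shape
assert grad_x.shape == x.shape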
template <typename T>
class Reshape2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -715,6 +736,7 @@ REGISTER_OPERATOR(reshape2,
ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::Reshape2CompositeGradOpMaker,
ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad,
ops::Reshape2GradOp,
......
......@@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
namespace paddle {
namespace operators {
......@@ -300,6 +302,25 @@ class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
}
};
class Transpose2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
paddle::experimental::Tensor xshape =
this->GetSingleForwardOutput("XShape");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto *dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
std::vector<int> axis = this->Attr<std::vector<int>>("axis");
VLOG(6) << "Running transpose2_grad composite func";
prim::transpose_grad<prim::DescTensor>(out_grad, axis, dx_ptr);
this->RecoverOutputName(dx, dx_name);
}
};
template <typename T>
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -365,7 +386,8 @@ REGISTER_OPERATOR(transpose2,
ops::Transpose2Op,
ops::Transpose2OpMaker,
ops::Transpose2GradMaker<paddle::framework::OpDesc>,
ops::Transpose2GradMaker<paddle::imperative::OpBase>);
ops::Transpose2GradMaker<paddle::imperative::OpBase>,
ops::Transpose2CompositeGradOpMaker);
REGISTER_OPERATOR(transpose2_grad,
ops::Transpose2OpGrad,
ops::TransposeGradInferVarType,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# prim api list
white_ops_list = [
"pow",
"scale",
"multiply",
"unsqueeze",
"expand",
"full",
"reshape",
"divide",
"sum",
"exp",
"scatter",
"transpose",
]
inplace_out_type_map = {
"Tensor": "Tensor&",
"std::vector<Tensor>": "std::vector<Tensor>&",
}
inplace_optional_out_type_map = {
"Tensor": "paddle::optional<Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<Tensor>>&",
}
class BaseAPI:
def __init__(self, api_item_yaml, prims=tuple()):
# self.api = api_item_yaml['op']
self.api = api_item_yaml['name']
self.is_prim_api = False
if api_item_yaml['name'] in prims:
self.is_prim_api = True
#######################################
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
# outputs:
# names : [], list of output names
# types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor>
########################################
if self.is_prim_api:
(
self.inputs,
self.attrs,
self.outputs,
self.optional_vars,
) = self.parse_args(self.api, api_item_yaml)
self.inplace_map = api_item_yaml['inplace']
def get_api_func_name(self):
return self.api
# def is_inplace(self):
# if self.inplace_map
# return True
# return False
def get_input_tensor_args(self, inplace_flag=False):
input_args = []
inplace_type_map = {
"const Tensor&": "Tensor&",
"const paddle::optional<Tensor>&": "paddle::optional<Tensor>&",
"const std::vector<Tensor>&": "std::vector<Tensor>&",
"const paddle::optional<std::vector<Tensor>>&": "paddle::optional<std::vector<Tensor>>&",
}
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(
inplace_type_map[self.inputs['input_info'][name]]
+ ' '
+ name
)
else:
input_args.append(self.inputs['input_info'][name] + ' ' + name)
return input_args
def get_declare_args(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(
self.attrs['attr_info'][name][0] + ' ' + name + default_value
)
return ", ".join(declare_args)
def get_declare_args_nodefault(self, inplace_flag=False):
declare_args = self.get_input_tensor_args(inplace_flag)
for name in self.attrs['names']:
declare_args.append(self.attrs['attr_info'][name][0] + ' ' + name)
return ", ".join(declare_args)
def get_return_type(self, inplace_flag=False):
out_type_list = []
for i, out_type in enumerate(self.outputs['types']):
out_name = self.outputs['names'][i].split('@')[0]
if inplace_flag and out_name in self.inplace_map:
if self.inplace_map[out_name] in self.optional_vars:
out_type_list.append(
inplace_optional_out_type_map[out_type]
)
else:
out_type_list.append(inplace_out_type_map[out_type])
else:
out_type_list.append(out_type)
if len(out_type_list) == 1:
return out_type_list[0]
else:
return "std::tuple<" + ", ".join(out_type_list) + ">"
def parse_args(self, api_name, api_item_yaml):
optional_vars = []
for input_dict in api_item_yaml['inputs']:
if input_dict['optional']:
optional_vars.append(input_dict['name'])
inputs, attrs = self.parse_input_and_attr(
api_item_yaml['inputs'], api_item_yaml['attrs']
)
output_type_list, output_names, out_size_expr = self.parse_output(
api_item_yaml['outputs']
)
return (
inputs,
attrs,
{
'names': output_names,
'types': output_type_list,
'out_size_expr': out_size_expr,
},
optional_vars,
)
def parse_input_and_attr(self, inputs_list, attrs_list):
input_types_map = {
'Tensor': 'const Tensor&',
'Tensor[]': 'const std::vector<Tensor>&',
}
attr_types_map = {
'IntArray': 'const IntArray&',
'Scalar': 'const Scalar&',
'Scalar(int)': 'const Scalar&',
'Scalar(int64_t)': 'const Scalar&',
'Scalar(float)': 'const Scalar&',
'Scalar(dobule)': 'const Scalar&',
'Scalar[]': 'const std::vector<phi::Scalar>&',
'int': 'int',
'int32_t': 'int32_t',
'int64_t': 'int64_t',
'long': 'long',
'size_t': 'size_t',
'float': 'float',
'float[]': 'const std::vector<float>&',
'double': 'double',
'bool': 'bool',
'bool[]': 'const std::vector<bool>&',
'str': 'const std::string&',
'str[]': 'const std::vector<std::string>&',
'Place': 'const Place&',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&',
}
optional_types_trans = {
'Tensor': 'const paddle::optional<Tensor>&',
'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
'int': 'paddle::optional<int>',
'int32_t': 'paddle::optional<int32_t>',
'int64_t': 'paddle::optional<int64_t>',
'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>',
'Place': 'paddle::optional<const Place&>',
'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>',
}
inputs = {'names': [], 'input_info': {}}
for input_dict in inputs_list:
inputs['names'].append(input_dict['name'])
if input_dict['optional']:
inputs['input_info'][input_dict['name']] = optional_types_trans[
input_dict['typename']
]
else:
inputs['input_info'][input_dict['name']] = input_types_map[
input_dict['typename']
]
attrs = {'names': [], 'attr_info': {}}
for attr_dict in attrs_list:
attrs['names'].append(attr_dict['name'])
if 'default_value' in attr_dict.keys():
default_value = attr_dict['default_value']
else:
default_value = None
if 'optional' in attr_dict.keys():
attrs['attr_info'][attr_dict['name']] = (
optional_types_trans[attr_dict['typename']],
default_value,
)
else:
attrs['attr_info'][attr_dict['name']] = (
attr_types_map[attr_dict['typename']],
default_value,
)
return inputs, attrs
def parse_output(self, outputs_list):
out_type_list = []
out_name_list = []
out_size_expr_list = []
for output_dict in outputs_list:
if output_dict['intermediate']:
continue
out_type_list.append(output_dict['typename'])
out_name_list.append(output_dict['name'])
if 'size' in output_dict.keys():
out_size_expr_list.append(output_dict['size'])
else:
out_size_expr_list.append(None)
return out_type_list, out_name_list, out_size_expr_list
class EagerPrimAPI(BaseAPI):
def __init__(self, api_item_yaml, prims=tuple()):
super().__init__(api_item_yaml, prims)
def get_api__func_name(self):
api_func_name = self.api
# if self.is_inplace:
# if api_func_name[-1] != '_':
# api_func_name += '_'
# print("after api name", api_func_name)
return api_func_name
def gene_prim_api_declaration(self):
api_declaration = ""
api_func_name = self.get_api__func_name()
if api_func_name[-1] != '_':
api_declaration = f"""
template <typename T>
{self.get_return_type()} {api_func_name}({self.get_declare_args()});
"""
else:
api_declaration = (
api_declaration
+ f"""
template <typename T>
{self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)});
"""
)
return api_declaration
def get_ad_func_input_args(self, inplace_flag=False):
input_args = []
for name in self.inputs['names']:
name = name.split('@')[0]
if inplace_flag and name in self.inplace_map.values():
input_args.append(name)
else:
input_args.append(name)
return input_args
def get_ad_func_args(self, inplace_flag=False):
ad_func_args = self.get_ad_func_input_args(inplace_flag)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
ad_func_args.append(name)
ad_func_args_str = ", ".join(ad_func_args)
return ad_func_args_str
def gene_ad_func_call(self):
api_func_name = self.get_api__func_name()
dygraph_ad_func_name = '::' + api_func_name + '_ad_func'
dygraph_ad_func_parameters = self.get_ad_func_args()
ad_func_call_str = f"""
VLOG(4) << "Eager Prim API {api_func_name}_ad_func call";
return {dygraph_ad_func_name}({dygraph_ad_func_parameters});
"""
# print("ad_func_call_str: ", ad_func_call_str)
return ad_func_call_str
def gene_eager_prim_api_code(self):
api_code = ""
indent = " "
api_func_name = self.get_api__func_name()
template = '<Tensor>'
# func declaration
if api_func_name[-1] != '_':
api_code = f"""
template <>
{self.get_return_type()} {api_func_name}{template}({self.get_declare_args_nodefault()})
"""
else:
api_code = f"""
template <>
{self.get_return_type(inplace_flag=True)} {api_func_name}{template}({self.get_declare_args_nodefault(inplace_flag=True)})
"""
# func code
api_code = api_code + '{'
api_code += f"""{self.gene_ad_func_call()}"""
api_code += '}' + '\n'
return api_code
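The generator above only reads a handful of fields from the parsed op YAML (name, inputs, attrs, outputs, inplace) and emits a <Tensor> specialization that forwards to the corresponding dygraph ad_func. A hypothetical, self-contained sketch of driving it, using an illustrative transpose entry that is not copied from ops.yaml:

# Hypothetical op entry shaped like the parsed YAML the classes above expect;
# the concrete values are illustrative, not taken from ops.yaml.
example_item = {
    'name': 'transpose',
    'inputs': [{'name': 'x', 'typename': 'Tensor', 'optional': False}],
    'attrs': [{'name': 'perm', 'typename': 'int[]'}],
    'outputs': [{'name': 'out', 'typename': 'Tensor', 'intermediate': False}],
    'inplace': None,
}

api = EagerPrimAPI(example_item, prims=('transpose',))
print(api.get_return_type())   # Tensor
print(api.get_declare_args())  # const Tensor& x, const std::vector<int>& perm
# gene_eager_prim_api_code() emits a template<> specialization whose body ends
# with "return ::transpose_ad_func(x, perm);"
print(api.gene_eager_prim_api_code())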
......@@ -25,6 +25,13 @@ using IntArray =
// This function should have the same signature as phi, which is defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
if (x_grad) {
auto res = cast<T>(out_grad, dtype);
set_output<T>(res, x_grad);
}
}
template <typename T>
void gather_grad(const Tensor& x,
const Tensor& index,
const Tensor& out_grad,
......@@ -65,6 +72,29 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
set_output<T>(grad_x_tmp, grad_x);
}
template <typename T>
void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
if (grad_x) {
auto grad_x_tmp = reshape<T>(grad_out, phi::vectorize(x.dims()));
set_output<T>(grad_x_tmp, grad_x);
}
}
template <typename T>
void transpose_grad(const Tensor& grad_out,
const std::vector<int>& perm,
Tensor* grad_x) {
if (grad_x) {
std::vector<int> reverse_perm(perm);
// build the inverse permutation so the gradient maps back to the original axes
for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
reverse_perm[perm[i]] = i;
}
auto grad_x_tmp = transpose<T>(grad_out, reverse_perm);
set_output<T>(grad_x_tmp, grad_x);
}
}
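The loop above builds the inverse permutation: since out = transpose(x, perm) moves axis perm[i] of x to position i, grad_x is obtained by transposing grad_out with reverse_perm, where reverse_perm[perm[i]] = i. A minimal numpy sketch with an illustrative permutation (not the Paddle API):

import numpy as np

perm = [0, 2, 1]
reverse_perm = [0] * len(perm)
for i, p in enumerate(perm):
    reverse_perm[p] = i              # reverse_perm[perm[i]] = i

x = np.random.rand(2, 3, 4)
grad_out = np.random.rand(*np.transpose(x, perm).shape)
grad_x = np.transpose(grad_out, reverse_perm)
assert grad_x.shape == x.shape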
template <typename T>
void subtract_grad(const Tensor& x,
const Tensor& y,
......
......@@ -33,5 +33,9 @@ Tensor full<Tensor>(const IntArray& shape,
VLOG(4) << "Eager Prim API full_ad_func call";
return ::full_ad_func(shape, value, dtype, place);
}
template <>
Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
return ::cast_ad_func(x, dtype);
}
} // namespace prim
} // namespace paddle
......@@ -36,5 +36,7 @@ Tensor full(const IntArray& shape,
const Scalar& value,
DataType dtype = DataType::FLOAT32,
const Place& place = CPUPlace());
template <typename T>
Tensor cast(const Tensor& x, DataType dtype);
} // namespace prim
} // namespace paddle
......@@ -120,6 +120,22 @@ Tensor full<DescTensor>(const IntArray& shape,
op->InferShape(*block);
return out;
}
template <>
Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
Tensor out = empty<DescTensor>({}, DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("cast");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->SetAttr("in_dtype", static_cast<int>(x.dtype()));
op->SetAttr("out_dtype", static_cast<int>(dtype));
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
} // namespace prim
} // namespace paddle
......@@ -180,6 +180,7 @@
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
invoke : cast (out_grad, x.dtype())
composite : cast_grad(x, out_grad)
no_need_buffer : x
- backward_op : channel_shuffle_grad
......@@ -378,7 +379,7 @@
param : [x, y]
kernel :
func : divide_grad
composite : divide_grad(x, y, out, out_grad, -1)
composite : divide_grad(x, y, out, out_grad, axis)
backward : divide_double_grad
- backward_op : dropout_grad
......@@ -1342,6 +1343,7 @@
kernel :
func : transpose_grad
backward : transpose_double_grad
composite : transpose_grad(out_grad, perm)
- backward_op : triangular_solve_grad
forward : triangular_solve (Tensor x, Tensor y, bool upper, bool tranpose, bool unitriangular) -> Tensor(out)
......
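The composite entries above (cast_grad, divide_grad, transpose_grad) are what the static composite grad makers consult once prim backward is enabled. A minimal usage sketch, assuming the internal core._set_prim_backward_enabled toggle exercised by the tests below; variable names and shapes are illustrative:

import paddle
from paddle.fluid import core

core._set_prim_backward_enabled(True)   # route backward through the composite rules
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    x = paddle.static.data('x', [3, 4, 10], 'float32')
    x.stop_gradient = False
    y = paddle.transpose(x, [0, 2, 1])
    x_grad = paddle.static.gradients(y, x)  # backward built from the transpose_grad composite
core._set_prim_backward_enabled(False)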
......@@ -53,19 +53,13 @@ core._set_prim_backward_enabled(True)
),
],
)
class TestTanhGradComp(unittest.TestCase):
class TestAddGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_tanh_grad_comp(self):
def test_add_grad_comp(self):
def actual(primal0, primal1):
core._set_prim_backward_enabled(True)
paddle.disable_static()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal', 'cotangent', 'src_dtype', 'dst_type'),
[
(
np.random.rand(10, 10),
np.random.rand(10, 10),
np.float32,
np.float64,
),
(
np.random.rand(10, 10),
np.random.rand(10, 10),
np.float64,
np.float32,
),
(
np.random.rand(10, 10),
np.random.rand(10, 10),
np.float32,
np.float32,
),
],
)
class TestCastGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.src_dtype)
cls.cotangent = cls.cotangent.astype(cls.src_dtype)
def test_cast_grad_comp(self):
core._set_prim_backward_enabled(True)
def actual(primal, cotangent):
x = paddle.to_tensor(primal)
x.stop_gradient = False
v = paddle.to_tensor(cotangent)
y = paddle.cast(x, self.dst_type)
x_cotangent = paddle.grad(y, x, v)
return x_cotangent
def desired(primal, cotangent):
return (cotangent * np.ones_like(primal)).astype(primal.dtype)
actual = actual(self.primal, self.cotangent)
desired = desired(self.primal, self.cotangent)
from paddle.fluid.data_feeder import _PADDLE_DTYPE_2_NUMPY_DTYPE
self.assertEqual(
_PADDLE_DTYPE_2_NUMPY_DTYPE[actual[0].dtype], desired.dtype
)
np.testing.assert_allclose(
actual=actual[0],
desired=desired,
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -53,19 +53,13 @@ core._set_prim_backward_enabled(True)
),
],
)
class TestTanhGradComp(unittest.TestCase):
class TestDivGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_tanh_grad_comp(self):
def test_div_grad_comp(self):
def actual(primal0, primal1):
core._set_prim_backward_enabled(True)
paddle.disable_static()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
@param.parameterized_class(
('primal', 'shape', 'cotangent', 'dtype'),
[
(
np.random.rand(10, 1, 10),
[10, 10],
np.random.rand(10, 10),
np.float32,
),
(np.random.rand(2, 60), [12, 10], np.random.rand(12, 10), np.float32),
],
)
class TestReshapeGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.dtype)
def test_reshape_grad_comp(self):
def actual(primal0, shape):
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
out = paddle.reshape(x, shape)
res = paddle.grad(out, [x], create_graph=True, retain_graph=True)
return res[0].numpy()
def desired(primal0, shape):
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
out = paddle.reshape(x, shape)
res = paddle.grad(out, [x], create_graph=True, retain_graph=True)
return res[0].numpy()
dx = actual(self.primal, self.shape)
ddx = desired(self.primal, self.shape)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -37,12 +37,6 @@ class TestSqrtGradComp(unittest.TestCase):
cls.primal = cls.primal.astype(cls.dtype)
cls.cotangent = cls.cotangent.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_sqrt_grad_comp(self):
def actual(primal, cotangent):
paddle.disable_static()
......
......@@ -53,19 +53,13 @@ core._set_prim_backward_enabled(True)
),
],
)
class TestTanhGradComp(unittest.TestCase):
class TestSubGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_tanh_grad_comp(self):
def test_sub_grad_comp(self):
def actual(primal0, primal1):
core._set_prim_backward_enabled(True)
paddle.disable_static()
......
......@@ -41,12 +41,6 @@ class TestTanhGradComp(unittest.TestCase):
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_tanh_grad_comp(self):
def actual(primal):
paddle.disable_static()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core._set_prim_backward_enabled(True)
@param.parameterized_class(
('primal', 'axis', 'cotangent', 'dtype'),
[
(
np.random.rand(
100,
),
[0],
np.random.rand(100),
np.float32,
),
(
np.random.rand(3, 4, 10),
[0, 2, 1],
np.random.rand(3, 10, 4),
np.float32,
),
(
np.random.rand(2, 3, 4, 5),
[0, 2, 3, 1],
np.random.rand(2, 4, 5, 3),
np.float32,
),
(
np.random.rand(2, 3, 4, 5, 6),
[4, 2, 3, 1, 0],
np.random.rand(6, 4, 5, 3, 2),
np.float32,
),
(
np.random.rand(2, 3, 4, 5, 6, 1),
[4, 2, 3, 1, 0, 5],
np.random.rand(6, 4, 5, 3, 2, 1),
np.float32,
),
# (np.random.rand(),
# [],
# np.random.rand(),
# np.float32),
],
)
class TestTransposeGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
if isinstance(cls.primal, np.ndarray):
cls.primal = cls.primal.astype(cls.dtype)
def test_transpose_grad_comp(self):
def actual(primal0, shape):
core._set_prim_backward_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
out = paddle.transpose(x, shape)
res = paddle.grad(out, [x], create_graph=True, retain_graph=True)
return res[0].numpy()
def desired(primal0, shape):
core._set_prim_backward_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
x.stop_gradient = False
out = paddle.transpose(x, shape)
res = paddle.grad(out, [x], create_graph=True, retain_graph=True)
return res[0].numpy()
dx = actual(self.primal, self.axis)
ddx = desired(self.primal, self.axis)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core, framework
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4)
def forward(self, x):
tmp = self.fc(x)
out = paddle.cast(tmp, paddle.float64)
return out
@param.parameterized_class(
('primal', 'cotangent', 'src_dtype', 'dst_type'),
[
(
np.random.rand(10, 10),
np.random.rand(10, 10),
np.float32,
np.float64,
),
(
np.random.rand(10, 10),
np.random.rand(10, 10),
np.float64,
np.float32,
),
(
np.random.rand(10, 10),
np.random.rand(10, 10),
np.float32,
np.float32,
),
],
)
class TestCastGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.src_dtype)
cls.cotangent = cls.cotangent.astype(cls.src_dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def train(self, use_prim, use_cinn):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
net = PrimeNet()
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x)
res = paddle.autograd.grad(out, [self.x])
return res
def test_cinn(self):
paddle.disable_static()
use_cinn = True
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(jiabin): CINN will crash in this case; enable it when fixed
use_cinn = False
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn)
for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-15,
atol=1e-15,
)
paddle.enable_static()
def test_cast_grad_comp(self):
core._set_prim_backward_enabled(True)
def actual(primal, cotangent):
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
x.stop_gradient = False
v = paddle.static.data(
'cotangent', cotangent.shape, cotangent.dtype
)
y = paddle.cast(x, self.dst_type)
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=mp.blocks[0].ops[-1].output('Out')[0],
)[0]
def desired(primal, cotangent):
return (cotangent * np.ones_like(primal)).astype(primal.dtype)
actual = actual(self.primal, self.cotangent)
desired = desired(self.primal, self.cotangent)
self.assertEqual(actual.dtype, desired.dtype)
np.testing.assert_allclose(
actual=actual,
desired=desired,
rtol=1e-6,
atol=0,
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core, framework
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4)
def forward(self, x):
tmp = self.fc(x)
out = paddle.reshape(tmp, [2, 1, 4])
return out
@param.parameterized_class(
('primal', 'shape', 'cotangent', 'dtype', "rtol"),
[
(
np.random.rand(10, 1, 10),
[10, 10],
np.random.rand(10, 10),
np.float32,
1e-5,
),
(
np.random.rand(2, 60),
[12, 10],
np.random.rand(12, 10),
np.float32,
1e-5,
),
(
np.random.rand(10, 1, 10),
[10, 10],
np.random.rand(10, 10),
np.float64,
1e-15,
),
(
np.random.rand(2, 60),
[12, 10],
np.random.rand(12, 10),
np.float64,
1e-15,
),
(
np.random.rand(10, 1, 10),
[10, 10],
np.random.rand(10, 10),
np.float16,
1e-3,
),
(
np.random.rand(2, 60),
[12, 10],
np.random.rand(12, 10),
np.float16,
1e-3,
),
],
)
class TestReshapeGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal = cls.primal.astype(cls.dtype)
cls.cotangent = cls.cotangent.astype(cls.dtype)
def train(self, use_prim, use_cinn):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
net = PrimeNet()
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x)
res = paddle.autograd.grad(out, [self.x])
return res
def test_cinn(self):
paddle.disable_static()
use_cinn = True
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(jiabin): CINN will crash in this case; enable it when fixed
use_cinn = False
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn)
for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-7,
atol=1e-7,
)
paddle.enable_static()
def test_reshape_grad_comp(self):
def actual(primal, shape, cotangent):
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
x.stop_gradient = False
v = paddle.static.data(
'cotangent', cotangent.shape, cotangent.dtype
)
y = paddle.reshape(x, shape)
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
)[0]
def desired(primal, shape, cotangent):
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal', primal.shape, primal.dtype)
x.stop_gradient = False
v = paddle.static.data(
'cotangent', cotangent.shape, cotangent.dtype
)
y = paddle.reshape(x, shape)
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
)[0]
if (self.dtype == np.float16) and isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# reshape doesn't support the fp16 kernel on CPU
pass
else:
np.testing.assert_allclose(
actual=actual(self.primal, self.shape, self.cotangent),
desired=desired(self.primal, self.shape, self.cotangent),
rtol=self.rtol,
atol=self.rtol,
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -69,7 +69,7 @@ class PrimeNet(paddle.nn.Layer):
),
],
)
class TestDivGradComp(unittest.TestCase):
class TestSubGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core, framework
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
def forward(self, x):
out = paddle.transpose(x, [0, 2, 1])
return out
@param.parameterized_class(
('primal', 'axis', 'cotangent', 'dtype', 'rtol'),
[
(
np.random.rand(
100,
),
[0],
np.random.rand(100),
np.float64,
1e-15,
),
(
np.random.rand(3, 4, 10),
[0, 2, 1],
np.random.rand(3, 10, 4),
np.float64,
1e-15,
),
(
np.random.rand(2, 3, 4, 5),
[0, 2, 3, 1],
np.random.rand(2, 4, 5, 3),
np.float64,
1e-15,
),
(
np.random.rand(2, 3, 4, 5, 6),
[4, 2, 3, 1, 0],
np.random.rand(6, 4, 5, 3, 2),
np.float64,
1e-15,
),
(
np.random.rand(2, 3, 4, 5, 6, 1),
[4, 2, 3, 1, 0, 5],
np.random.rand(6, 4, 5, 3, 2, 1),
np.float64,
1e-15,
),
(
np.random.rand(
100,
),
[0],
np.random.rand(100),
np.float16,
1e-3,
),
(
np.random.rand(3, 4, 10),
[0, 2, 1],
np.random.rand(3, 10, 4),
np.float16,
1e-3,
),
(
np.random.rand(2, 3, 4, 5),
[0, 2, 3, 1],
np.random.rand(2, 4, 5, 3),
np.float16,
1e-3,
),
(
np.random.rand(2, 3, 4, 5, 6),
[4, 2, 3, 1, 0],
np.random.rand(6, 4, 5, 3, 2),
np.float16,
1e-3,
),
(
np.random.rand(2, 3, 4, 5, 6, 1),
[4, 2, 3, 1, 0, 5],
np.random.rand(6, 4, 5, 3, 2, 1),
np.float16,
1e-3,
),
],
)
class TestTransposeGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
if isinstance(cls.primal, np.ndarray):
cls.primal = cls.primal.astype(cls.dtype)
if isinstance(cls.cotangent, np.ndarray):
cls.cotangent = cls.cotangent.astype(cls.dtype)
def train(self, use_prim, use_cinn):
paddle.seed(2022)
self.x = paddle.randn([3, 4, 10])
self.x.stop_gradient = False
net = PrimeNet()
core._set_prim_backward_enabled(use_prim)
net = apply_to_static(net, use_cinn)
out = net(self.x)
res = paddle.autograd.grad(out, [self.x])
return res
def _test_cinn(self):
paddle.disable_static()
use_cinn = True
if isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# TODO(jiabin): CINN will crash in this case; enable it when fixed
use_cinn = False
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn)
for i in range(len(dy_res)):
np.testing.assert_allclose(
comp_st_cinn_res[i].numpy(),
dy_res[i].numpy(),
rtol=1e-7,
atol=1e-7,
)
def test_transpose_grad_comp(self):
paddle.enable_static()
def actual(primal, axis, cotangent):
core._set_prim_backward_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
if isinstance(primal, np.ndarray):
x = paddle.static.data('primal', primal.shape, primal.dtype)
else:
x = paddle.static.data('primal', [1], "float32")
x.stop_gradient = False
if isinstance(cotangent, np.ndarray):
v = paddle.static.data(
'cotangent', cotangent.shape, cotangent.dtype
)
else:
v = paddle.static.data('cotangent', [1], "float32")
y = paddle.transpose(x, axis)
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
)[0]
def desired(primal, axis, cotangent):
core._set_prim_backward_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
if isinstance(primal, np.ndarray):
x = paddle.static.data('primal', primal.shape, primal.dtype)
else:
x = paddle.static.data('primal', [1], "float32")
x.stop_gradient = False
if isinstance(cotangent, np.ndarray):
v = paddle.static.data(
'cotangent', cotangent.shape, cotangent.dtype
)
else:
v = paddle.static.data('cotangent', [1], "float32")
y = paddle.transpose(x, axis)
x_cotangent = paddle.static.gradients(y, x, v)
exe = paddle.static.Executor()
exe.run(sp)
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=[x_cotangent[0].name],
)[0]
if (self.dtype == np.float16) and isinstance(
framework._current_expected_place(), framework.core.CPUPlace
):
# transpose doesn't support the fp16 kernel on CPU
pass
else:
np.testing.assert_allclose(
actual=actual(self.primal, self.axis, self.cotangent),
desired=desired(self.primal, self.axis, self.cotangent),
rtol=self.rtol,
atol=self.rtol,
)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()