Unverified commit 1c7001e7, authored by H hong, committed by GitHub

Add dropout yaml (#41355)

* add dropout slice yaml

* remove useless code

* fix infer shape error

* skip infrt compile for dropout
Parent 119816f9
@@ -777,10 +777,17 @@ void OpDesc::CheckAttrs() {
   checker->Check(&attrs_);
 }

-void OpDesc::InferShape(const BlockDesc &block) const {
+void OpDesc::InferShape(const BlockDesc &block) {
   try {
     VLOG(3) << "CompileTime infer shape on " << Type();
-    auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
+    auto &op_info = OpInfoMap::Instance().Get(this->Type());
+    auto *checker = op_info.Checker();
+    if (checker != nullptr) {
+      // set default attribute values here
+      VLOG(10) << "begin to check attribute of " << Type();
+      checker->Check(&attrs_);
+    }
+    auto &infer_shape = op_info.infer_shape_;
     PADDLE_ENFORCE_EQ(
         static_cast<bool>(infer_shape), true,
         platform::errors::NotFound(
...
@@ -142,7 +142,7 @@ class OpDesc {
   void CheckAttrs();

-  void InferShape(const BlockDesc &block) const;
+  void InferShape(const BlockDesc &block);

   void InferVarType(BlockDesc *block) const;
...
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <string>

 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/infermeta/unary.h"
+#include "paddle/phi/infermeta/binary.h"

 namespace paddle {
 namespace operators {
...
@@ -776,6 +776,26 @@ void DistInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }

+void DropoutInferMeta(const MetaTensor& x,
+                      paddle::optional<const MetaTensor&> seed_tensor,
+                      float p,
+                      bool is_test,
+                      const std::string& mode,
+                      int seed,
+                      bool fix_seed,
+                      MetaTensor* out,
+                      MetaTensor* mask) {
+  auto x_dims = x.dims();
+  out->set_dims(x_dims);
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+
+  if (mask != nullptr) {
+    mask->set_dims(x_dims);
+    mask->set_dtype(DataType::UINT8);
+  }
+}
+
 void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   auto x_dims = x.dims();
   auto x_rank = static_cast<size_t>(x_dims.size());
...
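The relocated DropoutInferMeta now also fixes the mask dtype. As a rough Python analogue of the metadata it propagates (illustration only, not generated code): the output keeps x's shape and dtype, and the mask, when requested, gets the same shape with uint8 elements.

import numpy as np

def dropout_infer_meta(x):
    # Mirror of the shape/dtype propagation above.
    out = np.empty_like(x)                    # same shape and dtype as x
    mask = np.empty(x.shape, dtype=np.uint8)  # same shape, uint8 elements
    return out, mask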
@@ -124,6 +124,16 @@ void DistInferMeta(const MetaTensor& x,

 void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

+void DropoutInferMeta(const MetaTensor& x,
+                      paddle::optional<const MetaTensor&> seed_tensor,
+                      float p,
+                      bool is_test,
+                      const std::string& mode,
+                      int seed,
+                      bool fix_seed,
+                      MetaTensor* out,
+                      MetaTensor* mask);
+
 void ElementwiseInferMeta(const MetaTensor& x,
                           const MetaTensor& y,
                           MetaTensor* out);
...
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/kernels/funcs/parse_qr_mode.h"
 #include "paddle/phi/kernels/funcs/pooling.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"
 #include "paddle/phi/kernels/funcs/unfold_functor.h"
 #include "paddle/phi/kernels/funcs/unsqueeze.h"
@@ -360,17 +361,6 @@ void DiagonalInferMeta(const MetaTensor& input,
   out->set_dims(phi::make_ddim(out_dims));
 }

-void DropoutInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* mask) {
-  auto x_dims = x.dims();
-  out->set_dims(x_dims);
-  out->share_lod(x);
-  out->set_dtype(x.dtype());
-
-  if (mask != nullptr) {
-    mask->set_dims(x_dims);
-  }
-}
-
 void EighInferMeta(const MetaTensor& x,
                    const std::string& uplo,
                    MetaTensor* out_w,
@@ -1738,6 +1728,51 @@ void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
   out->set_dims({1});
 }

+void SliceRawInferMeta(const MetaTensor& input,
+                       const std::vector<int64_t>& axes,
+                       const IntArray& starts_arr,
+                       const IntArray& ends_arr,
+                       const std::vector<int64_t>& infer_flags_t,
+                       const std::vector<int64_t>& decrease_axis,
+                       MetaTensor* out,
+                       MetaConfig config) {
+  auto in_dims = input.dims();
+  PADDLE_ENFORCE_LT(
+      in_dims.size(),
+      7,
+      phi::errors::InvalidArgument("The rank of input should be less than 7."));
+  DDim out_dims(in_dims);
+
+  std::vector<int64_t> infer_flags = infer_flags_t;
+  if (infer_flags.empty()) {
+    // Initialize infer_flags with 1.
+    // To be compatible with other op tests in which infer_flags is not set.
+    infer_flags = std::vector<int64_t>(axes.size(), 1);
+  }
+
+  // 2.1 Check attrs.
+  std::vector<int64_t> starts = starts_arr.GetData();
+  std::vector<int64_t> ends = ends_arr.GetData();
+
+  phi::funcs::CheckAndUpdateSliceAttrs<int64_t>(
+      in_dims, axes, &starts, &ends, nullptr, &infer_flags);
+
+  auto slice_dims = phi::funcs::GetSliceDims<int64_t>(
+      in_dims, axes, starts, ends, nullptr, &infer_flags);
+  if (config.is_runtime) {
+    out_dims = phi::funcs::GetDecreasedDims<int64_t>(
+        slice_dims, decrease_axis, &infer_flags);
+  } else {
+    out_dims = phi::funcs::GetDecreasedDims<int64_t>(
+        slice_dims, decrease_axis, nullptr);
+  }
+
+  out->set_dims(out_dims);
+  if (axes.size() > 0 && axes[0] != 0) {
+    out->share_lod(input);
+  }
+}
+
 void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
   auto dim_x = x.dims();
   auto rank_x = dim_x.size();
...
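For readers unfamiliar with the phi slice helpers, here is a rough Python sketch of the shape arithmetic SliceRawInferMeta performs when all starts/ends are known constants (negative indices wrapped, ranges clamped, decreased axes squeezed at runtime). The real logic lives in CheckAndUpdateSliceAttrs, GetSliceDims, and GetDecreasedDims and also handles unknown (-1) dimensions, which this sketch ignores.

def slice_out_shape(in_shape, axes, starts, ends, decrease_axis, is_runtime=True):
    # Simplified analogue of CheckAndUpdateSliceAttrs + GetSliceDims + GetDecreasedDims.
    out = list(in_shape)
    for axis, start, end in zip(axes, starts, ends):
        dim = in_shape[axis]
        start = start + dim if start < 0 else start   # wrap negative indices
        end = end + dim if end < 0 else end
        start = min(max(start, 0), dim)               # clamp to [0, dim]
        end = min(max(end, 0), dim)
        out[axis] = max(end - start, 0)
    if is_runtime and decrease_axis:
        # Axes listed in decrease_axis are squeezed out of the output shape.
        out = [d for i, d in enumerate(out) if i not in set(decrease_axis)]
    return out

# Matches the eager test further below: paddle.slice(a, [0, 1, 2], [-3, 0, 2], [3, 2, 4])
print(slice_out_shape([4, 5, 6], [0, 1, 2], [-3, 0, 2], [3, 2, 4], []))  # [2, 2, 2]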
@@ -80,8 +80,6 @@ void DiagInferMeta(const MetaTensor& x,
 void DiagonalInferMeta(
     const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);

-void DropoutInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* mask);
-
 void EighInferMeta(const MetaTensor& x,
                    const std::string& uplo,
                    MetaTensor* out_w,
@@ -271,6 +269,15 @@ void ShardIndexInferMeta(const MetaTensor& in,

 void SizeInferMeta(const MetaTensor& input, MetaTensor* out);

+void SliceRawInferMeta(const MetaTensor& input,
+                       const std::vector<int64_t>& axes,
+                       const IntArray& starts,
+                       const IntArray& ends,
+                       const std::vector<int64_t>& infer_flags,
+                       const std::vector<int64_t>& decrease_axis,
+                       MetaTensor* out,
+                       MetaConfig config = MetaConfig());
+
 void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out);

 void SplitInferMeta(const MetaTensor& x_meta,
...
@@ -1337,6 +1337,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
                 continue
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
         # infer_shape and infer_type
+        op_desc.check_attrs()
         op_desc.infer_var_type(block.desc)
         op_desc.infer_shape(block.desc)
...
@@ -5141,7 +5141,6 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
            # [-0.33972208 -0.43014923  0.31772556  0.76617881 -0.10761525]]
    """
    if len(x.shape) == 1:
        axis = 0
    if _non_static_mode():
@@ -11199,18 +11198,15 @@ def slice(input, axes, starts, ends):
        infer_flags = list(1 for i in range(len(axes)))
        tmp_tensor_type = core.eager.Tensor

        if isinstance(starts, (list, tuple)):
            starts = [
                item.numpy().item(0)
                if isinstance(item, tmp_tensor_type) else item
                for item in starts
            ]
-            attrs += ('starts', starts)
        elif isinstance(starts, tmp_tensor_type):
-            starts_tensor = starts
-            starts.stop_gradient = True
-            infer_flags = list(-1 for i in range(len(axes)))
+            tensor_t = starts.numpy()
+            starts = [ele for ele in tensor_t]

        if isinstance(ends, (list, tuple)):
            ends = [
@@ -11219,12 +11215,11 @@ def slice(input, axes, starts, ends):
            ]
            attrs += ('ends', ends)
        elif isinstance(ends, tmp_tensor_type):
-            ends_tensor = ends
-            ends_tensor.stop_gradient = True
-            infer_flags = list(-1 for i in range(len(axes)))
+            tensor_t = ends.numpy()
+            ends = [ele for ele in tensor_t]

-        return _C_ops.slice(input, starts_tensor, ends_tensor, None, None,
-                            'axes', axes, 'infer_flags', infer_flags, *attrs)
+        return _C_ops.final_state_slice(input, axes, starts, ends, infer_flags,
+                                        [])
    else:
        if _in_legacy_dygraph():
            attrs = ()
...
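With the branch above, eager-mode paddle.slice unpacks Tensor starts/ends into plain Python lists via .numpy() and dispatches to _C_ops.final_state_slice. A small usage sketch, assuming eager/dygraph mode is active; only shapes are checked:

import paddle

x = paddle.rand([4, 5, 6])
# List starts/ends and Tensor starts/ends now take the same final-state path.
y1 = paddle.slice(x, axes=[0, 1], starts=[1, 0], ends=[3, 2])
y2 = paddle.slice(x, axes=[0, 1],
                  starts=paddle.to_tensor([1, 0]),
                  ends=paddle.to_tensor([3, 2]))
assert y1.shape == y2.shape == [2, 2, 6]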
@@ -22,8 +22,11 @@ import paddle
 import paddle.static as static
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard
 import os
+from paddle import _C_ops


 class TestDropoutOp(OpTest):
     def setUp(self):
@@ -960,6 +963,19 @@ class TestDropoutBackward(unittest.TestCase):
                 np.array_equal(input.gradient(
                 ), self.cal_grad_downscale_in_infer(mask.numpy())))

+    def test_backward_downscale_in_infer_eager(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                with _test_eager_guard():
+                    input = paddle.uniform([40, 40], dtype="float32")
+                    input.stop_gradient = False
+                    out, mask = _C_ops.final_state_dropout(
+                        input, None, 0.5, False, "downgrade_in_infer", 0, False)
+                    out.backward()
+                    self.assertTrue(
+                        np.array_equal(input.gradient(
+                        ), self.cal_grad_downscale_in_infer(mask.numpy())))
+
     def test_backward_upscale_train(self):
         for place in self.places:
             with fluid.dygraph.guard(place):
@@ -976,6 +992,21 @@ class TestDropoutBackward(unittest.TestCase):
                 np.allclose(input.gradient(
                 ), self.cal_grad_upscale_train(mask.numpy(), prob)))

+    def test_backward_upscale_train_eager(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                with _test_eager_guard():
+                    prob = 0.5
+                    input = paddle.uniform([40, 40], dtype="float32")
+                    input.stop_gradient = False
+                    out, mask = _C_ops.final_state_dropout(
+                        input, None, 0.5, False, "upscale_in_train", 0, False)
+                    out.backward()
+                    self.assertTrue(
+                        np.allclose(input.gradient(
+                        ), self.cal_grad_upscale_train(mask.numpy(), prob)))
+
     def test_backward_upscale_train_2(self):
         for place in self.places:
             with fluid.dygraph.guard(place):
...
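The reference helpers cal_grad_downscale_in_infer and cal_grad_upscale_train are defined elsewhere in this test file and are not part of the diff; assuming standard dropout semantics (and out.backward() with an all-ones upstream gradient), they amount to roughly:

import numpy as np

def cal_grad_downscale_in_infer(mask):
    # downgrade_in_infer: forward during training is out = x * mask,
    # so with dout == 1 the input gradient equals the mask.
    return mask.astype("float32")

def cal_grad_upscale_train(mask, prob):
    # upscale_in_train: forward is out = x * mask / (1 - p),
    # so the gradient carries the same 1 / (1 - p) scaling.
    return mask.astype("float32") / (1 - prob)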
@@ -21,6 +21,7 @@ from op_test import OpTest, convert_float_to_uint16
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle
+from paddle.fluid.framework import _test_eager_guard

 paddle.enable_static()
@@ -599,6 +600,31 @@ class TestSliceApiWithTensor(unittest.TestCase):
         self.assertTrue(np.array_equal(y_paddle.numpy(), y_np))

+class TestSliceApiEager(unittest.TestCase):
+    def test_slice_api(self):
+        with paddle.fluid.dygraph.guard():
+            with _test_eager_guard():
+                a = paddle.rand(shape=[4, 5, 6], dtype='float32')
+                a.stop_gradient = False
+                axes = [0, 1, 2]
+                starts = [-3, 0, 2]
+                ends = [3, 2, 4]
+                a_1 = paddle.slice(a, axes=axes, starts=starts, ends=ends)
+
+                a_2 = paddle.slice(
+                    a,
+                    axes=axes,
+                    starts=paddle.to_tensor(starts),
+                    ends=paddle.to_tensor(ends))
+
+                a_1.backward()
+                grad_truth = paddle.zeros_like(a)
+                grad_truth[-3:3, 0:2, 2:4] = 1
+                self.assertTrue(np.array_equal(grad_truth, a.gradient()))
+
+                self.assertTrue(np.allclose(a_1.numpy(), a[-3:3, 0:2, 2:4]))
+
+
 class TestSliceApiWithLoDTensorArray(unittest.TestCase):
     def setUp(self):
         self.shape = (3, 4)
...
@@ -28,7 +28,7 @@ from ...tensor import clip
 from ...tensor import sum
 from ...tensor import sqrt
 from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
-from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode
+from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
 from ...fluid import dygraph_utils
 from ...fluid import layers
@@ -895,9 +895,15 @@ def dropout(x,
            seed = None
        mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  #semantic transfer

-    if in_dynamic_mode():
+    if _non_static_mode():
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
+
+        if in_dygraph_mode():
+            out, mask = _C_ops.final_state_dropout( x, None, p, not training, mode, \
+                seed if seed is not None else 0, seed is not None)
+
+            return out
        out, mask = _C_ops.dropout(
            x, 'dropout_prob', p, 'is_test', not training, 'fix_seed',
            seed is not None, 'seed', seed
...
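The new in_dygraph_mode() branch routes paddle.nn.functional.dropout through the generated final-state kernel while keeping the public Python API unchanged. A quick usage check of the standard behaviour (not specific to this PR):

import paddle

x = paddle.rand([2, 3])
# Training: roughly half of the elements are zeroed; survivors are scaled by 1 / (1 - p).
y_train = paddle.nn.functional.dropout(x, p=0.5, training=True, mode='upscale_in_train')
# Inference: in upscale_in_train mode, dropout is an identity op.
y_eval = paddle.nn.functional.dropout(x, p=0.5, training=False, mode='upscale_in_train')
assert bool(paddle.allclose(y_eval, x))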
@@ -463,6 +463,16 @@
   kernel :
     func : dot

+- api : dropout
+  args : (Tensor x, Tensor seed_tensor, float p, bool is_test, str mode, int seed, bool fix_seed)
+  output : Tensor(out), Tensor(mask)
+  infer_meta :
+    func : DropoutInferMeta
+  kernel :
+    func : dropout
+  optional : seed_tensor
+  backward : dropout_grad
+
 # eigh
 - api : eigh
   args : (Tensor x, str uplo)
@@ -1504,6 +1514,15 @@
   kernel :
     func : size

+- api : slice
+  args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
+  output : Tensor
+  infer_meta :
+    func : SliceRawInferMeta
+  kernel :
+    func : slice
+  backward : slice_grad
+
 # soft_shrink
 - api : soft_shrink
   args : (Tensor x, float lambda)
...
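Each "- api :" entry drives generation of an eager final_state_* entry point whose positional arguments follow the args list. The calls below are implied by the yaml and by the updated tests in this PR rather than taken from generated sources; a sketch, assuming an eager guard as used in the tests:

import paddle
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    x = paddle.rand([4, 5])
    # dropout: (x, seed_tensor, p, is_test, mode, seed, fix_seed) -> (out, mask)
    out, mask = _C_ops.final_state_dropout(x, None, 0.5, False, "upscale_in_train", 0, False)
    # slice: (input, axes, starts, ends, infer_flags, decrease_axis) -> out
    part = _C_ops.final_state_slice(x, [0], [1], [3], [1], [])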
@@ -301,6 +301,17 @@
   kernel :
     func : divide_grad

+- backward_api : dropout_grad
+  forward : dropout (Tensor x, Tensor seed_tensor, float p, bool is_test, str mode, int seed, bool fix_seed) -> Tensor(out), Tensor(mask)
+  args : (Tensor mask, Tensor out_grad, float p, bool is_test, str mode)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out_grad]
+  kernel :
+    func : dropout_grad
+  optional : seed_tensor
+
 - backward_api : eigh_grad
   forward : eigh (Tensor x, str uplo) -> Tensor(out_w), Tensor(out_v)
   args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
@@ -1054,6 +1065,16 @@
   kernel :
     func : sinh_grad

+- backward_api : slice_grad
+  forward : slice (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(out)
+  args : (Tensor input, Tensor out_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
+  output : Tensor(input_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [input]
+  kernel :
+    func : slice_grad
+
 - backward_api : soft_shrink_grad
   forward : soft_shrink (Tensor x, float lambda) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float lambda)
...
 {
-  "phi_apis":["conj", "nll_loss", "flatten"],
+  "phi_apis":["conj", "nll_loss", "dropout", "flatten"],
   "phi_kernels":["equal_all"]
 }