Unverified · Commit bfeedd29 authored by mengziheng, committed by GitHub

Pad grad (#53374)

* add pad op

* add_some_code

* modify some code

* add some code

* add some code

* modify some code

* add some code

* modify some code

* Update composite_backward_api.h

* modify some code

* add some code

* add some code

* add some code
Parent 0b6dd535
@@ -17,6 +17,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/complex.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/infermeta/unary.h"

 namespace paddle {
@@ -129,6 +132,27 @@ class PadOpGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };

+class PadCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // Composite backward: express pad's gradient with primitive ops
+    // instead of the pad_grad kernel.
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
+    auto* dx_ptr = this->GetOutputPtr(&x_grad);
+    std::string dx_name = this->GetOutputName(x_grad);
+    std::vector<int> paddings =
+        static_cast<std::vector<int>>(this->Attr<std::vector<int>>("paddings"));
+    float pad_value = static_cast<float>(this->Attr<float>("pad_value"));
+    VLOG(6) << "Running pad_grad composite func";
+    prim::pad_grad<prim::DescTensor>(x, out_grad, paddings, pad_value, dx_ptr);
+    this->RecoverOutputName(x_grad, dx_name);
+  }
+};
+
 template <typename T>
 class PadOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -155,6 +179,7 @@ REGISTER_OPERATOR(pad,
                   ops::PadOpMaker,
                   ops::PadOpGradMaker<paddle::framework::OpDesc>,
                   ops::PadOpGradMaker<paddle::imperative::OpBase>,
+                  ops::PadCompositeGradOpMaker,
                   PadInferShapeFunctor);

 REGISTER_OPERATOR(pad_grad,
                   ops::PadOpGrad,
...
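With prim enabled, this maker replaces the pad_grad kernel by the composite rule defined below: the input gradient is simply a slice of the output gradient. A minimal sketch of the bounds it derives, in plain Python (the helper slice_bounds is hypothetical, for illustration only):

# Hypothetical helper mirroring the bounds prim::pad_grad derives.
# `paddings` is flat: [before_0, after_0, before_1, after_1, ...].
def slice_bounds(paddings, out_dims):
    rank = len(out_dims)
    axes = list(range(rank))
    starts = [paddings[2 * i] for i in range(rank)]
    ends = [out_dims[i] - paddings[2 * i + 1] for i in range(rank)]
    return axes, starts, ends

# A (2, 3) input padded with [1, 0, 2, 2] produces out_dims (3, 7);
# the gradient slice is axes [0, 1], starts [1, 2], ends [3, 5].
print(slice_bounds([1, 0, 2, 2], (3, 7)))  # ([0, 1], [1, 2], [3, 5])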
@@ -1806,6 +1806,32 @@ void roll_grad(const Tensor& x,
   }
 }

+template <typename T>
+void pad_grad(const Tensor& input,
+              const Tensor& out_grad,
+              const std::vector<int>& paddings,
+              const Scalar& pad_value,
+              Tensor* input_grad) {
+  if (input_grad) {
+    size_t rank = input.dims().size();
+    auto out_dims = out_grad.dims();
+
+    std::vector<int64_t> starts(rank, 0);
+    std::vector<int64_t> ends(rank, 0);
+    std::vector<int64_t> axes(rank, 0);
+    std::vector<int64_t> infer_flags(rank, 1);
+    std::vector<int64_t> decrease_axis({});
+    for (size_t i = 0; i < rank; ++i) {
+      // paddings is [before_0, after_0, before_1, after_1, ...], so the
+      // gradient of dim i lives in [before_i, out_dim_i - after_i).
+      starts[i] = static_cast<int64_t>(paddings[2 * i]);
+      ends[i] = static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]);
+      axes[i] = static_cast<int64_t>(i);
+    }
+    auto out_tmp =
+        slice<T>(out_grad, axes, starts, ends, infer_flags, decrease_axis);
+    set_output<T>(out_tmp, input_grad);
+  }
+}
+
 template <typename T>
 void scatter_nd_add_grad(const Tensor& index,
                          const Tensor& updates,
@@ -1821,5 +1847,6 @@ void scatter_nd_add_grad(const Tensor& index,
     set_output<T>(tmp_updates_grad, updates_grad);
   }
 }

 }  // namespace prim
 }  // namespace paddle
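The rule is easy to validate numerically: constant padding only relocates entries, so the gradient with respect to x is the upstream gradient with the padded border sliced off. A self-contained NumPy check (all names are local to this sketch, not part of the patch):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
pad_width = [(1, 0), (2, 2)]              # flat paddings [1, 0, 2, 2]
out_grad = rng.standard_normal((3, 7))    # shape of pad(x)

# loss = <pad(x), out_grad>; its analytic gradient w.r.t. x is out_grad
# sliced back to x's extent: [before_i : out_dim_i - after_i) per dim.
def loss(v):
    return (np.pad(v, pad_width, constant_values=0.0) * out_grad).sum()

numeric = np.zeros_like(x)
eps = 1e-6
for idx in np.ndindex(x.shape):
    d = np.zeros_like(x)
    d[idx] = eps
    numeric[idx] = (loss(x + d) - loss(x - d)) / (2 * eps)

np.testing.assert_allclose(numeric, out_grad[1:3, 2:5], rtol=1e-5)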
@@ -33,5 +33,16 @@ Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
   return ::cast_ad_func(x, dtype);
 }

+template <>
+Tensor slice<Tensor>(const Tensor& input,
+                     const std::vector<int64_t>& axes,
+                     const IntArray& starts,
+                     const IntArray& ends,
+                     const std::vector<int64_t>& infer_flags,
+                     const std::vector<int64_t>& decrease_axis) {
+  VLOG(4) << "Eager Prim API slice_ad_func call";
+  return ::slice_ad_func(input, axes, starts, ends, infer_flags, decrease_axis);
+}
+
 }  // namespace prim
 }  // namespace paddle
@@ -38,5 +38,13 @@ Tensor full(const IntArray& shape,
 template <typename T>
 Tensor cast(const Tensor& x, DataType dtype);

+template <typename T>
+Tensor slice(const Tensor& input,
+             const std::vector<int64_t>& axes,
+             const IntArray& starts,
+             const IntArray& ends,
+             const std::vector<int64_t>& infer_flags,
+             const std::vector<int64_t>& decrease_axis);
+
 }  // namespace prim
 }  // namespace paddle
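The prim slice mirrors the public paddle.slice op, which takes the same axes/starts/ends triple (infer_flags and decrease_axis are attributes the framework normally fills in). A hedged usage sketch:

import paddle

x = paddle.arange(12, dtype='float32').reshape([3, 4])
# Keep rows [1, 3) and columns [0, 2) -- the same bounds pad_grad would
# derive for paddings [1, 0, 0, 2] on a (2, 2) input.
y = paddle.slice(x, axes=[0, 1], starts=[1, 0], ends=[3, 2])
print(y.shape)  # [2, 2]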
@@ -127,5 +127,32 @@ Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
   return out;
 }

+template <>
+Tensor slice<DescTensor>(const Tensor& input,
+                         const std::vector<int64_t>& axes,
+                         const IntArray& starts,
+                         const IntArray& ends,
+                         const std::vector<int64_t>& infer_flags,
+                         const std::vector<int64_t>& decrease_axis) {
+  // Static-graph lowering: append a "slice" op to the current block.
+  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
+  framework::OpDesc* op = block->AppendOp();
+  op->SetType("slice");
+  op->SetInput(
+      "Input",
+      {std::static_pointer_cast<prim::DescTensor>(input.impl())->Name()});
+  auto out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+  op->SetOutput(
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
+  op->SetAttr("axes", unsafe_vector_cast<int64_t, int>(axes));
+  op->SetAttr("starts", unsafe_vector_cast<int64_t, int>(starts.GetData()));
+  op->SetAttr("ends", unsafe_vector_cast<int64_t, int>(ends.GetData()));
+  op->SetAttr("infer_flags", unsafe_vector_cast<int64_t, int>(infer_flags));
+  op->SetAttr("decrease_axis", unsafe_vector_cast<int64_t, int>(decrease_axis));
+  op->CheckAttrs();
+  op->InferVarType(block);
+  op->InferShape(*block);
+  return out;
+}
+
 }  // namespace prim
 }  // namespace paddle
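With the static specialization in place, enabling prim backward should leave a slice op in the generated program where pad_grad used to be. A sketch of how one might inspect that (the flag name follows the Paddle 2.5-era test suite and may differ across versions):

import paddle
from paddle.fluid import core

paddle.enable_static()
core._set_prim_backward_enabled(True)  # assumption: 2.5-era flag name

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', [2, 3], 'float32')
    x.stop_gradient = False
    y = paddle.nn.functional.pad(x, [1, 0, 2, 2], value=0.0)
    paddle.static.gradients([y.sum()], [x])

# With the composite rule active, the backward section is expected to
# contain 'slice' rather than 'pad_grad':
print([op.type for op in main.block(0).ops])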
@@ -722,6 +722,7 @@
     func : pad_grad
     param: [out_grad, paddings, pad_value]
     no_need_buffer : x
+    composite : pad_grad(x, out_grad, paddings, pad_value, x_grad)
     backward : pad_double_grad

 - backward_op : pool2d_double_grad
...
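This composite entry is what routes pad's backward through the decomposition above when prim is enabled. Its effect can be sanity-checked against eager mode (a sketch assuming a standard Paddle install; shapes and paddings are arbitrary examples):

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 3])
x.stop_gradient = False
out = F.pad(x, [1, 0, 2, 2], mode='constant', value=0.0)  # shape [3, 7]
out_grad = paddle.randn(out.shape)
(x_grad,) = paddle.grad(out, x, grad_outputs=out_grad)
# The pad gradient is out_grad with the padded border sliced away:
assert paddle.allclose(x_grad, out_grad[1:3, 2:5]).item()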
@@ -24,7 +24,9 @@ from paddle.fluid import Program, core, program_guard

 def pad_wrapper(x, paddings, pad_value):
-    return paddle._C_ops.pad(x, paddings, float(pad_value))
+    return paddle.nn.functional.pad(
+        x, pad=list(paddings), mode='constant', value=pad_value
+    )


 class TestPadOp(OpTest):
@@ -37,7 +39,7 @@ class TestPadOp(OpTest):
             'X': np.random.random(self.shape).astype(self.dtype),
         }
         self.attrs = {}
-        self.attrs['paddings'] = np.array(self.paddings).flatten()
+        self.attrs['paddings'] = list(np.array(self.paddings).flatten())
         self.attrs['pad_value'] = self.pad_value
         self.outputs = {
             'Out': np.pad(
@@ -47,6 +49,9 @@ class TestPadOp(OpTest):
                 constant_values=self.pad_value,
             )
         }
+        self.prim_op_type = "prim"
+        self.public_python_api = pad_wrapper
+        self.enable_cinn = False

     def get_dtype(self):
         return np.float64
@@ -55,7 +60,7 @@ class TestPadOp(OpTest):
         self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)

     def initTestCase(self):
         self.shape = (16, 16)
@@ -111,16 +116,19 @@ create_test_fp16(TestCase3)

 class TestPadOpError(unittest.TestCase):
     def test_errors(self):
-        with program_guard(Program(), Program()):
-            input_data = np.random.random((2, 2)).astype("float32")
+        with paddle.fluid.framework._static_guard():
+            with program_guard(Program(), Program()):
+                input_data = np.random.random((2, 2)).astype("float32")

-            def test_Variable():
-                paddle.nn.functional.pad(x=input_data, pad=[1, 1, 1, 1])
+                def test_Variable():
+                    paddle.nn.functional.pad(x=input_data, pad=[1, 1, 1, 1])

-            self.assertRaises(TypeError, test_Variable)
+                self.assertRaises(TypeError, test_Variable)

-            data = paddle.static.data(name='data', shape=[4], dtype='float16')
-            paddle.nn.functional.pad(x=data, pad=[0, 1])
+                data = paddle.static.data(
+                    name='data', shape=[4], dtype='float16'
+                )
+                paddle.nn.functional.pad(x=data, pad=[0, 1])


 class TestPaddingValueTensor(UnittestBase):
@@ -129,34 +137,40 @@ class TestPaddingValueTensor(UnittestBase):
         self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())

     def test_static(self):
-        main_prog = Program()
-        starup_prog = Program()
-        with program_guard(main_prog, starup_prog):
-            fc = paddle.nn.Linear(4, 10)
-            x = paddle.randn([2, 4])
-            x.stop_gradient = False
-            feat = fc(x)  # [2,3,10]
-
-            out = self.call_func(feat)
-
-            sgd = paddle.optimizer.SGD()
-            sgd.minimize(paddle.mean(out))
-            self.assertTrue(self.var_prefix() in str(main_prog))
-
-            exe = paddle.static.Executor()
-            exe.run(starup_prog)
-            res = exe.run(fetch_list=[feat, out])
-            gt = np.pad(res[0], [1, 1], 'constant', constant_values=[1.0, 1.0])
-            np.testing.assert_allclose(res[1], gt)
-            paddle.static.save_inference_model(
-                self.save_path, [x], [feat, out], exe
-            )
-            # Test for Inference Predictor
-            infer_outs = self.infer_prog()
-            gt = np.pad(
-                infer_outs[0], [1, 1], 'constant', constant_values=[1.0, 1.0]
-            )
-            np.testing.assert_allclose(infer_outs[1], gt)
+        with paddle.fluid.framework._static_guard():
+            main_prog = Program()
+            starup_prog = Program()
+            with program_guard(main_prog, starup_prog):
+                fc = paddle.nn.Linear(4, 10)
+                x = paddle.randn([2, 4])
+                x.stop_gradient = False
+                feat = fc(x)  # shape [2, 10]
+
+                out = self.call_func(feat)
+
+                sgd = paddle.optimizer.SGD()
+                sgd.minimize(paddle.mean(out))
+                self.assertTrue(self.var_prefix() in str(main_prog))
+
+                exe = paddle.static.Executor()
+                exe.run(starup_prog)
+                res = exe.run(fetch_list=[feat, out])
+                gt = np.pad(
+                    res[0], [1, 1], 'constant', constant_values=[1.0, 1.0]
+                )
+                np.testing.assert_allclose(res[1], gt)
+                paddle.static.save_inference_model(
+                    self.save_path, [x], [feat, out], exe
+                )
+                # Test for Inference Predictor
+                infer_outs = self.infer_prog()
+                gt = np.pad(
+                    infer_outs[0],
+                    [1, 1],
+                    'constant',
+                    constant_values=[1.0, 1.0],
+                )
+                np.testing.assert_allclose(infer_outs[1], gt)

     def path_prefix(self):
         return 'padding_value'
@@ -183,23 +197,26 @@ class TestPaddingValueTensor2(TestPaddingValueTensor):

 class TestPaddingValueTensor3(unittest.TestCase):
     def test_static(self):
-        np_x = np.random.random((16, 16)).astype('float32')
-        main_prog = Program()
-        starup_prog = Program()
-        with program_guard(main_prog, starup_prog):
-            x = paddle.assign(np_x).astype('float32')
-            pad_value = paddle.assign([0.0]).astype('float64')
-            y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)
-            loss = y.sum()
-            optimize_ops, params_grads = paddle.optimizer.SGD(0.01).minimize(
-                loss
-            )
-
-        exe = paddle.static.Executor(paddle.CPUPlace())
-        res = exe.run(main_prog, fetch_list=[y] + [g for p, g in params_grads])
-        pd_out = res[0]
-        np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
-        np.testing.assert_allclose(pd_out, np_out)
+        with paddle.fluid.framework._static_guard():
+            np_x = np.random.random((16, 16)).astype('float32')
+            main_prog = Program()
+            starup_prog = Program()
+            with program_guard(main_prog, starup_prog):
+                x = paddle.assign(np_x).astype('float32')
+                pad_value = paddle.assign([0.0]).astype('float64')
+                y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value)
+                loss = y.sum()
+                optimize_ops, params_grads = paddle.optimizer.SGD(
+                    0.01
+                ).minimize(loss)
+
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            res = exe.run(
+                main_prog, fetch_list=[y] + [g for p, g in params_grads]
+            )
+            pd_out = res[0]
+            np_out = np.pad(np_x, [(0, 1), (2, 3)], constant_values=0.0)
+            np.testing.assert_allclose(pd_out, np_out)

 @unittest.skipIf(
@@ -215,13 +232,16 @@ class TestPadBP16Op(OpTest):
         self.python_api = pad_wrapper
         x = np.random.random(self.shape).astype(np.float32)
         self.attrs = {}
-        self.attrs['paddings'] = np.array(self.paddings).flatten()
+        self.attrs['paddings'] = list(np.array(self.paddings).flatten())
         self.attrs['pad_value'] = self.pad_value
         out = np.pad(
             x, self.paddings, mode='constant', constant_values=self.pad_value
         )
         self.inputs = {'X': convert_float_to_uint16(x)}
         self.outputs = {'Out': convert_float_to_uint16(out)}
+        self.enable_cinn = False
+        self.prim_op_type = "prim"
+        self.public_python_api = pad_wrapper

     def initTestCase(self):
         self.shape = (16, 16)
@@ -234,9 +254,9 @@ class TestPadBP16Op(OpTest):
     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out')
+        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)


 if __name__ == '__main__':
-    paddle.enable_static()
+    # paddle.enable_static()
     unittest.main()
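Aside: the BF16 test stores bfloat16 data as uint16 via convert_float_to_uint16. The packing is just the top half of an IEEE float32; a standalone NumPy sketch of that round trip (truncating, as the test helper does; helper names here are local to the sketch):

import numpy as np

def float_to_uint16(x):
    # Keep the high 16 bits of the float32 representation (bfloat16).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def uint16_to_float(x):
    # Re-expand bfloat16 bits into a float32 with a zeroed low half.
    return (x.astype(np.uint32) << 16).view(np.float32)

a = np.array([1.5, -2.25, 3.1], dtype=np.float32)
print(uint16_to_float(float_to_uint16(a)))
# [ 1.5     -2.25    3.09375] -- 3.1 loses its low mantissa bits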