Unverified commit 297182f7 authored by SylarTiaNII, committed by GitHub

add assign composite backward op (#51430)

* add assign composite backward op

* fix log msg

* code style

* fix comp rule

* replace assign with by_pass
Parent f124c86f
@@ -17,6 +17,10 @@ limitations under the License. */
 #include <string>

 #include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"

 namespace paddle {
@@ -109,6 +113,23 @@ class AssignGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };

+class AssignCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor input_grad = this->GetSingleInputGrad("X");
+    auto dx_ptr = this->GetOutputPtr(&input_grad);
+    std::string dx_name = this->GetOutputName(input_grad);
+    VLOG(6) << "Running assign_grad composite func";
+    prim::assign_grad<prim::DescTensor>(out_grad, dx_ptr);
+    this->RecoverOutputName(input_grad, dx_name);
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});

 }  // namespace operators
@@ -122,6 +143,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(assign,
                             PD_INFER_META(phi::UnchangedInferMeta));
 REGISTER_OPERATOR(assign,
                   ops::AssignOp,
+                  ops::AssignCompositeGradOpMaker,
                   ops::AssignGradMaker<paddle::framework::OpDesc>,
                   ops::AssignGradMaker<paddle::imperative::OpBase>,
                   ops::AssignOpProtoMaker,
...
@@ -930,6 +930,13 @@ void gather_nd_grad(const Tensor& x,
   }
 }

+template <typename T>
+void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    by_pass<T>(out_grad, x_grad);
+  }
+}
+
 template <typename T>
 void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
...
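The composite rule is just the identity VJP: `assign(x)` returns its input unchanged, so the upstream gradient passes straight through, which is exactly what `by_pass` expresses. A minimal NumPy sketch of that math (illustrative only, not Paddle's implementation):

```python
import numpy as np

def assign(x):
    # Forward: assign is the identity (a copy of its input).
    return x.copy()

def assign_grad(out_grad):
    # Backward: d(assign(x))/dx is the identity, so the incoming
    # gradient is forwarded unchanged -- the same effect as
    # by_pass(out_grad, x_grad) in the composite rule above.
    return out_grad.copy()

x = np.random.rand(100, 10)
out_grad = np.ones_like(x)
np.testing.assert_allclose(assign_grad(out_grad), out_grad)
```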
@@ -94,6 +94,7 @@
   forward : assign (Tensor x) -> Tensor(out)
   args : (Tensor out_grad)
   output : Tensor(x_grad)
+  composite: assign_grad(out_grad, x_grad)
   invoke : assign(out_grad)

 - backward_op : assign_out__grad
...
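With both `composite` and `invoke` present in the YAML entry, the framework can decompose `assign_grad` into primitive ops when prim mode is enabled, and otherwise falls back to `invoke : assign(out_grad)`. A hedged sketch of building a static-graph program whose backward would route through this rule (standard `paddle.static` API; the prim-mode switch that actually selects the composite path is not shown here):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[100, 10], dtype='float64')
    x.stop_gradient = False
    out = paddle.assign(x)
    # Appends backward ops to the program; with the composite rule
    # active, these are primitive ops produced by by_pass rather
    # than a dedicated assign_grad kernel.
    (x_grad,) = paddle.static.gradients([out], [x])
print(x_grad.name)
paddle.disable_static()
```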
@@ -30,6 +30,8 @@ class TestAssignOp(op_test.OpTest):
     def setUp(self):
         self.python_api = paddle.assign
         self.op_type = "assign"
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random(size=(100, 10)).astype('float64')
         self.inputs = {'X': x}
         self.outputs = {'Out': x}
@@ -41,7 +43,7 @@ class TestAssignOp(op_test.OpTest):
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
         paddle.disable_static()
@@ -49,6+51,8 @@ class TestAssignFP16Op(op_test.OpTest):
     def setUp(self):
         self.python_api = paddle.assign
         self.op_type = "assign"
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random(size=(100, 10)).astype('float16')
         self.inputs = {'X': x}
         self.outputs = {'Out': x}
@@ -60,7 +64,7 @@ class TestAssignFP16Op(op_test.OpTest):
     def test_backward(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_eager=True)
+        self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
         paddle.disable_static()
...
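`check_prim=True` makes `OpTest` additionally compare the gradient computed through the composite decomposition against the reference gradient, while `enable_cinn = False` keeps the CINN compiler out of that comparison. As an independent sanity check of the same gradient, a short eager-mode snippet (standard Paddle autograd API; it exercises assign's backward, though not necessarily the prim path):

```python
import numpy as np
import paddle

x = paddle.rand([100, 10], dtype='float64')
x.stop_gradient = False
out = paddle.assign(x)
# assign is the identity, so d(sum(assign(x)))/dx is all ones.
out.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), np.ones([100, 10]))
```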