未验证 提交 305bd25b 编写于 作者: C chengduo 提交者: GitHub

[Cherry pick] Fix register op without gradient (#19272)

* fix REGISTER_OP_WITHOUT_GRADIENT
test=develop
上级 1bb013fa
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <memory>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
...@@ -53,8 +54,9 @@ class Registrar { ...@@ -53,8 +54,9 @@ class Registrar {
template <typename... ARGS> template <typename... ARGS>
struct OperatorRegistrar : public Registrar { struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) { explicit OperatorRegistrar(const char* op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), if (OpInfoMap::Instance().Has(op_type)) {
"'%s' is registered more than once.", op_type); PADDLE_THROW("'%s' is registered more than once.", op_type);
}
static_assert(sizeof...(ARGS) != 0, static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass"); "OperatorRegistrar should be invoked at least by OpClass");
OpInfo info; OpInfo info;
...@@ -206,7 +208,8 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I, ...@@ -206,7 +208,8 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
} }
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OPERATOR(op_type, op_class, op_maker_class) REGISTER_OPERATOR(op_type, op_class, op_maker_class, \
paddle::framework::EmptyGradOpMaker)
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
......
...@@ -19,7 +19,7 @@ import paddle.fluid as fluid ...@@ -19,7 +19,7 @@ import paddle.fluid as fluid
from simple_nets import init_data from simple_nets import init_data
def simple_net1(): def case1_fill_grad_vars():
x = fluid.layers.data(name='image', shape=[784], dtype='float32') x = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feature = fluid.layers.fc(input=x, size=20, act=None) feature = fluid.layers.fc(input=x, size=20, act=None)
...@@ -30,7 +30,7 @@ def simple_net1(): ...@@ -30,7 +30,7 @@ def simple_net1():
return loss return loss
def simple_net2(): def case2_prune_no_grad_branch():
x = fluid.layers.data(name='image', shape=[784], dtype='float32') x = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feature = fluid.layers.fc(input=x, size=10, act=None) feature = fluid.layers.fc(input=x, size=10, act=None)
...@@ -42,14 +42,28 @@ def simple_net2(): ...@@ -42,14 +42,28 @@ def simple_net2():
return loss return loss
def case3_prune_no_grad_branch2():
    """Build a label-only network and return its mean as the loss.

    The int64 'label' input is cast to float32, cast back to int64,
    one-hot encoded with depth 100, and averaged. No trainable layer
    is created here.
    NOTE(review): per the function name, this presumably exercises
    pruning of a branch that has no gradient -- confirm against the
    backward implementation.
    """
    raw_label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # Round-trip the dtype (int64 -> float32 -> int64) before encoding.
    float_label = fluid.layers.cast(raw_label, dtype="float32")
    int_label = fluid.layers.cast(float_label, dtype='int64')
    encoded = fluid.layers.one_hot(input=int_label, depth=100)
    return fluid.layers.mean(encoded)
def case4_with_no_grad_op_maker():
    """Return the mean of a [20, 30] gaussian_random tensor as the loss.

    The network takes no data inputs.
    NOTE(review): per the function name, gaussian_random is presumably
    an op registered without a gradient op maker -- confirm against the
    op registration.
    """
    noise = fluid.layers.gaussian_random(shape=[20, 30])
    return fluid.layers.mean(noise)
class TestBackward(unittest.TestCase): class TestBackward(unittest.TestCase):
def check_backward(self, model): def check_backward(self, model, feed_dict):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
batch_size = 2
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
loss = model() loss = model()
...@@ -58,12 +72,16 @@ class TestBackward(unittest.TestCase): ...@@ -58,12 +72,16 @@ class TestBackward(unittest.TestCase):
optimizer.minimize(loss) optimizer.minimize(loss)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
img, label = init_data(batch_size, img_shape=[784], label_range=9) exe.run(feed=feed_dict)
exe.run(feed={'image': img, 'label': label})
def test_backward(self): def test_backward(self):
self.check_backward(simple_net1) batch_size = 2
self.check_backward(simple_net2) img, label = init_data(batch_size, img_shape=[784], label_range=9)
feed_dict = {'image': img, 'label': label}
self.check_backward(case1_fill_grad_vars, feed_dict)
self.check_backward(case2_prune_no_grad_branch, feed_dict)
self.check_backward(case3_prune_no_grad_branch2, {'label': label})
self.check_backward(case4_with_no_grad_op_maker, {})
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册