diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 6c26183818a9d6996e3d3ce2af74ba36f4711eca..b2813da83d9e4c525e66bb1f79b28769627eaec2 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -260,6 +260,12 @@ class OpRegistry { return CreateOp(op_desc.type(), inputs, outputs, attrs); } + static bool SupportGPU(const std::string& op_type) { + OperatorWithKernel::OpKernelKey key; + key.place_ = platform::GPUPlace(); + return OperatorWithKernel::AllOpKernels().at(op_type).count(key) != 0; + } + static std::shared_ptr CreateGradOp(const OperatorBase& op) { PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index cbb86c4195a6c7e976fc5e0dd69d77be46dfb17c..d4ac8fda5411a51c7c4b86400ae5506c39a64c00 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -200,6 +200,8 @@ All parameter, weight, gradient are variables in Paddle. return OpRegistry::CreateOp(desc); }); + operator_base.def_static("support_gpu", &OpRegistry::SupportGPU); + operator_base.def("backward", [](const OperatorBase &forwardOp, const std::unordered_set &no_grad_vars) { diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index e6bc7d8a9b5ddd4582a5ef8a47cb63a7e5911892..636828064f7d98e6d5ac80e6d1292270a62eacf3 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -28,7 +28,7 @@ class OpTestMeta(type): kwargs = dict() places = [] places.append(core.CPUPlace()) - if core.is_compile_gpu(): + if core.is_compile_gpu() and core.Operator.support_gpu(self.type): places.append(core.GPUPlace(0)) for place in places: @@ -66,7 +66,9 @@ class OpTestMeta(type): for out_name in func.all_output_args: actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] - numpy.isclose(actual, expect) + self.assertTrue( + numpy.allclose(actual, expect), + "output name: " + out_name + "has diff") obj.test_all = test_all return obj