Commit 7506e481 authored by Yu Yang, committed by GitHub

Merge pull request #4660 from reyoung/feature/polish_infer_shape

Polish CompileTime InferShape
@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
 proto_library(framework_proto SRCS framework.proto)
 cc_library(attribute SRCS attribute.cc DEPS framework_proto)
-cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
+cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim)
 cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
 cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
 cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
......
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/op_desc.h"
+#include <functional>
+#include <unordered_map>
 #include "paddle/framework/block_desc.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace framework {
@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
     need_update_ = false;
   }
 }
+
+using InferShapeFuncMap =
+    std::unordered_map<std::string /*op_type*/,
+                       std::function<void(InferShapeContext *)>>;
+
+static InferShapeFuncMap &InferShapeFuncs() {
+  static InferShapeFuncMap *g_map = nullptr;
+  if (g_map == nullptr) {
+    g_map = new InferShapeFuncMap();
+    auto &info_map = OpInfoMap::Instance();
+    // all registered kernels
+    for (auto &pair : OperatorWithKernel::AllOpKernels()) {
+      auto &info = info_map.Get(pair.first);
+      // use empty type here to avoid runtime checks.
+      auto op =
+          static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
+      g_map->insert(
+          {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
+    }
+  }
+  return *g_map;
+}
+
+void OpDescBind::InferShape(const BlockDescBind &block) const {
+  auto &funcs = InferShapeFuncs();
+  auto it = funcs.find(this->Type());
+  if (it == funcs.end()) {
+    PADDLE_THROW("Operator %s has not been registered", this->Type());
+  }
+  CompileTimeInferShapeContext ctx(*this, block);
+  it->second(&ctx);
+}
+
 }  // namespace framework
 }  // namespace paddle
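The new `InferShapeFuncs()` helper builds, on first use, a map from op type to a type-erased callback that forwards to `OperatorWithKernel::InferShape`, so `OpDescBind::InferShape` can dispatch by name instead of creating an operator and `dynamic_cast`-ing it on every call (the old pybind path removed further down). Below is a minimal standalone sketch of the same lazily initialized function-map pattern; the context type is simplified and the op names are hypothetical, so it illustrates the idiom rather than the Paddle code itself.

#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Simplified stand-ins: a "context" and a per-op callback type, loosely
// mirroring std::function<void(InferShapeContext *)> in the patch above.
struct Ctx {
  std::string op_type;
};
using ShapeFn = std::function<void(Ctx *)>;
using FnMap = std::unordered_map<std::string, ShapeFn>;

// Build the map once on first use. Allocating it with `new` and never
// deleting it is a common trick to avoid static-destruction-order issues;
// note that this lazy initialization is not synchronized.
static FnMap &Registry() {
  static FnMap *g_map = nullptr;
  if (g_map == nullptr) {
    g_map = new FnMap();
    for (const char *type : {"sum", "mul"}) {  // hypothetical op types
      g_map->emplace(type, [](Ctx *ctx) {
        std::cout << "infer shape for " << ctx->op_type << "\n";
      });
    }
  }
  return *g_map;
}

// Dispatch by name, mirroring OpDescBind::InferShape's lookup-and-call.
void InferShapeByName(const std::string &op_type) {
  auto &funcs = Registry();
  auto it = funcs.find(op_type);
  if (it == funcs.end()) {
    throw std::runtime_error("operator " + op_type + " is not registered");
  }
  Ctx ctx{op_type};
  it->second(&ctx);
}

int main() {
  InferShapeByName("mul");  // prints: infer shape for mul
  return 0;
}

Capturing one operator instance per type in the registered lambda, as the real code does, means the construction cost is paid once when the registry is built; every later call is just a hash lookup plus an indirect call.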
@@ -100,6 +100,8 @@ class OpDescBind {
     return &this->attrs_;
   }
 
+  void InferShape(const BlockDescBind &block) const;
+
  private:
   template <typename MapType>
   static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
@@ -196,7 +196,8 @@ void BindOpDesc(py::module &m) {
       .def("set_attr", &OpDescBind::SetAttr)
       .def("attr", &OpDescBind::GetAttr)
       .def("set_block_attr", &OpDescBind::SetBlockAttr)
-      .def("get_block_attr", &OpDescBind::GetBlockAttr);
+      .def("get_block_attr", &OpDescBind::GetBlockAttr)
+      .def("infer_shape", &OpDescBind::InferShape);
 }
 
 }  // namespace pybind
......
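The only change to the Python binding is one extra `.def` in the existing chain, exposing `OpDescBind::InferShape` as `infer_shape`. For readers unfamiliar with pybind11, here is a minimal sketch of that method-binding pattern; `FakeOpDesc` and the module name are hypothetical stand-ins, not the real Paddle binding.

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in class used only to illustrate the binding pattern.
struct FakeOpDesc {
  int out_rank = 0;
  void InferShape() { out_rank = 2; }  // pretend shape inference
};

PYBIND11_MODULE(fake_core, m) {
  // Chained .def calls register methods on the Python class, exactly like
  // the .def("infer_shape", &OpDescBind::InferShape) line in the diff.
  py::class_<FakeOpDesc>(m, "OpDesc")
      .def(py::init<>())
      .def("infer_shape", &FakeOpDesc::InferShape)
      .def_readonly("out_rank", &FakeOpDesc::out_rank);
}

From Python this would be used as `d = fake_core.OpDesc(); d.infer_shape()`, the same call shape the updated test below uses (`sum_op_desc.infer_shape(block)`).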
@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
                      desc.InitializationErrorString());
         return OpRegistry::CreateOp(desc);
       })
-      .def_static("infer_shape",
-                  [](OpDescBind &op_desc, BlockDescBind &block) {
-                    auto op = OpRegistry::CreateOp(*op_desc.Proto());
-                    auto *op_with_kernel =
-                        dynamic_cast<OperatorWithKernel *>(op.get());
-                    if (op_with_kernel != nullptr) {
-                      auto ctx = CompileTimeInferShapeContext(op_desc, block);
-                      op_with_kernel->InferShape(&ctx);
-                    } else {
-                      PADDLE_THROW(
-                          "OP(%s) is not type of OperatorWithKernel, "
-                          "should not call this function",
-                          op_desc.Type());
-                    }
-                  })
       .def("backward",
            [](const OperatorBase &forwardOp,
               const std::unordered_set<std::string> &no_grad_vars) {
......
 import unittest
 import paddle.v2.framework.core as core
-from paddle.v2.framework.op import Operator
 
 
 class TestInferShape(unittest.TestCase):
@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
         sum_op_desc.set_input("X", ["x1", "x2"])
         sum_op_desc.set_output("Out", ["out"])
 
-        core.Operator.infer_shape(sum_op_desc, block)
+        sum_op_desc.infer_shape(block)
         self.assertEqual(out.shape(), shape)
 
     def test_mul_op(self):
@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
         mul_op_desc.set_attr("x_num_col_dims", 1)
         mul_op_desc.set_attr("y_num_col_dims", 1)
 
-        core.Operator.infer_shape(mul_op_desc, block)
+        mul_op_desc.infer_shape(block)
         self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
......