未验证 提交 f2972fbf 编写于 作者: T Tink_Y 提交者: GitHub

Update new_op_cn.md

fix  #132
上级 9b15d29c
......@@ -150,8 +150,9 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("Y")->dims();
//never use Input or Output if you want a to get a LoDTensor.
auto dim0 = ctx.Input<LoDTensor>("X")->dims();
auto dim1 = ctx.Input<LoDTensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X"));
......@@ -161,7 +162,7 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
ctx.Output<LoDTensor>("Out")->Resize({dim0[0], dim1[1]});
}
};
```
......@@ -201,6 +202,8 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs,
-`InferShapeContext`相比,`ExecutionContext`增加了设备类型,同样可获取到输入输出和属性参数。
- `Compute`函数里实现`OpKernel`的具体计算逻辑。
Op的输入和输出可分别通过ExecutionContext::Input()和ExecutionContext::Output()获得。注意:若op的输入/输出的变量类型是LoDTensor(fluid默认所有的Tensor默认都是LoDTensor类型),请写成ExecutionContext::Input()和ExecutionContext::Output(),不要写ExecutionContext::Input()和ExecutionContext::Output()。因为若实际的变量类型为SelectedRows,Input()和Output()方法会将SelectedRows类型特化为Tensor,导致潜在的错误。
下面是 `MulKernel` `Compute`的实现:
```cpp
......@@ -208,9 +211,9 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs,
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
auto* X = context.Input<LoDTensor>("X");
auto* Y = context.Input<LoDTensor>("Y");
auto* Z = context.Output<LoDTensor>("Out");
Z->mutable_data<T>(context.GetPlace());
auto& device_context = context.template device_context<DeviceContext>();
math::matmul<DeviceContext, T>(*X, false, *Y, false, 1, Z, 0, device_context);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册