# 如何计算 tensorflow 和 pytorch 模型的浮点运算数

## 1. 引言

FLOPs 是 floating point operations 的缩写，指浮点运算数，可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。

## 2. 模型结构

| Layers | Channels | Kernels | Strides | Units | Activation |
|--------|----------|---------|---------|-------|------------|
| Conv2D | 32       | (4, 4)  | (1, 2)  | —     | relu       |
| GRU    | —        | —       | —       | 96    | —          |
| Dense  | —        | —       | —       | 256   | sigmoid    |

```from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model

def test_model_tf(Input_shape):
# shape: [B, C, T, F]
main_input = Input(batch_shape=Input_shape, name="main_inputs")

conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation="relu", data_format="channels_first", name="conv")(main_input)

# shape: [B, T, FC]
gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
gru = GRU(units=96, reset_after=True, return_sequences=True, name="gru")(gru)

output = Dense(256, activation="sigmoid", name="output")(gru)

model = Model(inputs=[main_input], outputs=[output])

return model```

```python
import torch
import torch.nn as nn

class test_model_torch(nn.Module):
    """PyTorch counterpart of the Keras model: Conv2d -> GRU -> Linear.

    Consumes [B, C, T, F] inputs (C=1, F=256 assumed — the GRU input size
    of 4064 = 32 channels * 127 freq bins is hard-wired to F=256) and
    returns [B, T', 256] sigmoid activations, T' = T - 3.
    """

    def __init__(self):
        super(test_model_torch, self).__init__()

        self.conv2d = nn.Conv2d(in_channels=1, out_channels=32,
                                kernel_size=(4, 4), stride=(1, 2))
        self.relu = nn.ReLU()

        # batch_first=True so the GRU consumes [B, T, FC] like the Keras
        # model; without it nn.GRU expects [T, B, FC] and the reshape below
        # would silently swap the time and batch axes.
        self.gru = nn.GRU(input_size=4064, hidden_size=96, batch_first=True)

        self.fc = nn.Linear(96, 256)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # shape: [B, C, T, F]
        out = self.conv2d(inputs)
        out = self.relu(out)

        # shape: [B, T, F*C] — one feature vector per timestep
        batch, channel, frame, freq = out.size()
        out = torch.reshape(out, (batch, frame, freq * channel))
        out, _ = self.gru(out)

        out = self.fc(out)
        out = self.sigmoid(out)

        return out
```

## 3. 计算模型的 FLOPs

### 3.1. tensorflow 1.12.0

```import tensorflow as tf
import tensorflow.keras.backend as K

def get_flops(model):
opts = tf.profiler.ProfileOptionBuilder.float_operation()

flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd="op", options=opts)

return flops.total_float_ops

if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print("FLOPs of tensorflow 1.12.0:", get_flops(model))```

### 3.2. tensorflow 2.3.1

```import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()

def get_flops(model):
opts = tf.profiler.ProfileOptionBuilder.float_operation()

flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd="op", options=opts)

return flops.total_float_ops

if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print("FLOPs of tensorflow 2.3.1:", get_flops(model))```

### 3.3. pytorch 1.10.1+cu102

```import thop

x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print("FLOPs of pytorch 1.10.1:", flops * 2)```

### 3.4. 结果对比

tensorflow 1.12.0：

tensorflow 2.3.1：

pytorch 1.10.1：