
Keras is a high-level neural network API written in pure Python that runs on top of Theano or TensorFlow. It is built for fast experimentation and lets you turn ideas into results quickly. Reading a model's code is not an easy way to understand a complex network structure; seeing the same structure drawn as a diagram is far quicker and more intuitive. This article uses Keras to draw the network structure of a bi-LSTM model.

I. Preparation

1. Install PyDot

pip install pydot

2. Install Graphviz

Download and install Graphviz from the official website: graphviz.org/

After installation, add the bin folder under the Graphviz installation directory to the system PATH environment variable so that the dot executable can be found.
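plot_model needs both pieces: pydot to build the graph description and the Graphviz dot executable to render it. The sketch below, which uses only the standard library, checks that Python can see both before you try to plot. The function name check_plot_prerequisites is hypothetical, and it only confirms that dot is found on PATH, not that it runs correctly.

```python
import importlib.util
import shutil


def check_plot_prerequisites():
    """Report whether pydot is importable and Graphviz's `dot` is on PATH."""
    return {
        # find_spec returns None when the package cannot be imported
        "pydot_installed": importlib.util.find_spec("pydot") is not None,
        # shutil.which returns None when the executable is not on PATH
        "dot_on_path": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_plot_prerequisites().items():
        print(f"{name}: {'OK' if ok else 'MISSING'}")
```

If dot_on_path reports MISSING after installing Graphviz, the PATH change usually has not reached the current terminal or IDE yet; restarting it is the first thing to try.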

II. Write the code

1. Import related packages

load_model: loads the saved network model

CRF: the network model contains a CRF layer (from keras_contrib), so its class must be supplied when loading the model

plot_model: draws the network model structure and saves it as an image

pyplot: displays the saved structure image

from keras.models import load_model
from keras_contrib.layers import CRF
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
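The import path keras.utils.vis_utils matches older standalone Keras; in some newer Keras releases that module is no longer available and plot_model is exposed directly in keras.utils. A minimal sketch, assuming you want one script to work across versions, tries several locations in order. The helper name import_first and the candidate list are assumptions for illustration, not part of Keras.

```python
import importlib


def import_first(candidates):
    """Return the first attribute importable from a list of (module, name) pairs."""
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            # remember why this candidate failed, then try the next one
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("no candidate could be imported:\n" + "\n".join(errors))


# Assumed locations of plot_model, from older to newer Keras layouts.
PLOT_MODEL_CANDIDATES = [
    ("keras.utils.vis_utils", "plot_model"),
    ("keras.utils", "plot_model"),
    ("tensorflow.keras.utils", "plot_model"),
]
```

Usage: plot_model = import_first(PLOT_MODEL_CANDIDATES), in place of the fixed import line above.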

2. Generate network model structure

plot_model parameters:

to_file: path and file name of the saved structure image

show_shapes: whether to display the input and output shapes of each layer

show_layer_names: whether to display the name of each layer

rankdir: layout direction of the graph; "TB" draws the layers from top to bottom (a vertical plot), "LR" from left to right (a horizontal plot)
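Under the hood, plot_model builds a Graphviz graph and lets Graphviz lay it out, which is why rankdir uses Graphviz's direction names. The pure-Python sketch below imitates that idea for a simple sequential stack so the effect of each parameter is visible in the generated DOT text. It is not Keras code: layers_to_dot is a hypothetical helper, and the layer names and shapes in EXAMPLE_LAYERS are made up to resemble a bi-LSTM-CRF model.

```python
GRAPHVIZ_RANKDIRS = {"TB", "LR", "BT", "RL"}

# Hypothetical bi-LSTM-CRF stack; names and shapes are invented for illustration.
EXAMPLE_LAYERS = [
    {"name": "input_1", "type": "InputLayer", "input": "(None, 100)", "output": "(None, 100)"},
    {"name": "embedding_1", "type": "Embedding", "input": "(None, 100)", "output": "(None, 100, 128)"},
    {"name": "bidirectional_1", "type": "Bidirectional(LSTM)", "input": "(None, 100, 128)", "output": "(None, 100, 256)"},
    {"name": "crf_1", "type": "CRF", "input": "(None, 100, 256)", "output": "(None, 100, 7)"},
]


def layers_to_dot(layers, show_shapes=True, show_layer_names=True, rankdir="TB"):
    """Build a Graphviz DOT description of a sequential stack of layers."""
    if rankdir not in GRAPHVIZ_RANKDIRS:
        raise ValueError(f"rankdir must be one of {sorted(GRAPHVIZ_RANKDIRS)}")
    lines = ["digraph model {", f"  rankdir={rankdir};", "  node [shape=box];"]
    for i, layer in enumerate(layers):
        # show_layer_names decides whether the layer name precedes its type
        label = f"{layer['name']}: {layer['type']}" if show_layer_names else layer["type"]
        if show_shapes:
            # "\\n" in Python becomes DOT's \n line-break escape inside the label
            label += f"\\ninput: {layer['input']}\\noutput: {layer['output']}"
        lines.append(f'  n{i} [label="{label}"];')
    # chain consecutive layers; rankdir sets the direction Graphviz lays them out
    for i in range(len(layers) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(layers_to_dot(EXAMPLE_LAYERS, rankdir="LR"))
```

Saving the printed text to model.dot and running dot -Tpng model.dot -o model.png renders it with Graphviz; switching rankdir between "TB" and "LR" rotates the chain between vertical and horizontal.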

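# path of the saved model file
model_path = "./model/ch_ner_model.h5"
# CRF is a custom layer, so its class must be passed via custom_objects;
# compile=False skips compiling the model, which is not needed for plotting
model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)
# pass real booleans: the string 'False' is truthy and would still show names
plot_model(model, to_file='./model/nerbilstm.png', show_shapes=True,
           show_layer_names=False, rankdir="TB")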
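A quick way to confirm that plot_model actually wrote the image, and to see how large the rendered graph is, is to read the PNG header. This standard-library sketch assumes to_file ends in .png (Keras picks the output format from the file extension); png_size is a hypothetical helper. A PNG file starts with an 8-byte signature followed by the IHDR chunk, whose first two fields are the width and height as big-endian 32-bit integers.

```python
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path):
    """Return (width, height) of a PNG file by reading its IHDR chunk."""
    with open(path, "rb") as f:
        header = f.read(24)
    # bytes 0-7: signature, 8-11: chunk length, 12-15: chunk type "IHDR"
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a valid PNG file")
    # bytes 16-23: width and height, big-endian unsigned 32-bit integers
    width, height = struct.unpack(">II", header[16:24])
    return width, height
```

Usage: png_size("./model/nerbilstm.png") raises an error if the file is missing or not a PNG, and otherwise returns its pixel dimensions.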

3. Load network model structure

Use matplotlib.pyplot to load and display the generated structure image.

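plt.figure(figsize=(10, 10))  # figure size in inches: (width, height)
img = plt.imread("./model/nerbilstm.png")  # read the image saved by plot_model
plt.imshow(img)
plt.axis("off")  # hide the axes around the image
plt.show()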
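A graph drawn with rankdir="TB" is usually tall and narrow, so a fixed square figure can leave a lot of empty space. One option, sketched here as an assumption rather than the author's approach, is to derive figsize from the image's pixel dimensions so the aspect ratio is preserved. figsize_for_image and its max_inches parameter are hypothetical.

```python
def figsize_for_image(width_px, height_px, max_inches=10.0):
    """Return a (width, height) figsize in inches that keeps the image's aspect ratio.

    The longer side of the image is scaled to max_inches.
    """
    if width_px <= 0 or height_px <= 0:
        raise ValueError("image dimensions must be positive")
    longest = max(width_px, height_px)
    return (max_inches * width_px / longest, max_inches * height_px / longest)
```

Usage: since plt.imread returns an array shaped (height, width, channels), read the image first and then call plt.figure(figsize=figsize_for_image(img.shape[1], img.shape[0])) before plt.imshow(img).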

4. Resulting image