如果想尝试使用 Google Colab 上的 TPU 来训练模型，也非常方便，只需添加 6 行代码。
在 Colab 笔记本中，依次点击「修改 → 笔记本设置」，并将「硬件加速器」设置为 TPU。
注：以下代码只能在 Colab 上正确执行。
可通过以下 Colab 链接（《tf_TPU》）测试效果：
https://colab.research.google.com/drive/1XCIhATyE1R7lq6uwFlYlRsUr5d9_-r1s
代码语言:javascript复制%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import *
一、准备数据
MAX_LEN = 300
BATCH_SIZE = 32

(x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
# Pad/truncate every sequence to exactly MAX_LEN token ids.
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

# Vocabulary size and class count must cover BOTH splits: a token id or
# label that only occurs in the test split would otherwise fall outside the
# Embedding table / output layer.
MAX_WORDS = max(x_train.max(), x_test.max()) + 1
CAT_NUM = max(y_train.max(), y_test.max()) + 1

# cache() must come BEFORE shuffle(): caching after shuffling freezes the
# first epoch's order and replays it on every later epoch. prefetch() goes
# last so input preparation overlaps with training.
ds_train = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .cache()
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# The validation split is not shuffled: the metrics are aggregated over the
# whole split, so iteration order does not matter.
ds_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
二、定义模型
# Reset state generated by Keras before building a fresh model.
tf.keras.backend.clear_session()
def create_model():
    """Build the (uncompiled) 1-D CNN text classifier.

    Layers: Embedding -> Conv1D(64, 5) -> MaxPool1D(2) -> Conv1D(32, 3)
    -> MaxPool1D(2) -> Flatten -> Dense(CAT_NUM, softmax).

    Reads the module-level constants MAX_LEN, MAX_WORDS and CAT_NUM, so
    they must be defined before this is called.

    Returns:
        A ``Sequential`` model whose last layer applies softmax, i.e. it
        outputs class probabilities rather than logits.
    """
    model = models.Sequential()
    # An explicit input spec builds the model immediately (so summary()
    # works before fit). It replaces Embedding's `input_length` argument,
    # which is deprecated in newer Keras.
    model.add(layers.Input(shape=(MAX_LEN,)))
    model.add(layers.Embedding(MAX_WORDS, 7))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model
def compile_model(model):
    """Compile *model* for multi-class classification with integer labels.

    Uses the Nadam optimizer and sparse categorical cross-entropy, and
    reports top-1 and top-5 accuracy.

    Args:
        model: an uncompiled Keras model whose outputs are class
            probabilities (``create_model`` ends in a softmax Dense layer).

    Returns:
        The same model object, compiled in place.
    """
    model.compile(
        optimizer=optimizers.Nadam(),
        # from_logits=False: the network already applies softmax. Treating
        # its probabilities as logits would apply softmax a second time and
        # distort both the loss and its gradients.
        loss=losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[
            metrics.SparseCategoricalAccuracy(),
            metrics.SparseTopKCategoricalAccuracy(5),
        ],
    )
    return model
三、训练模型
# The following 6 lines are the only additions needed to train on a TPU.
import os

# Build the TPU address from the COLAB_TPU_ADDR environment variable
# (raises KeyError if it is not set, e.g. when no TPU runtime is selected).
# NOTE(review): the value presumably is a bare host:port, hence the
# 'grpc://' scheme prefix — confirm against the Colab runtime.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

# Create and compile the model inside the strategy scope so its variables
# are created under the TPU distribution strategy.
with strategy.scope():
    model = create_model()
    model.summary()
    model = compile_model(model)
# Train for 10 epochs, evaluating on the test split after each epoch.
history = model.fit(ds_train, validation_data=ds_test, epochs=10)