好得很程序员自学网

<tfoot draggable='sEl'></tfoot>

CIFAR10自定义网络实战

目录

CIFAR10 MyDenseLayer

CIFAR10

MyDenseLayer

import os

# Set before TensorFlow is imported so the native (C++) log level is already
# in place when the library initialises; '2' hides INFO and WARNING output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Rescale an image to [-1, 1] and cast its label to int32.

    Args:
        x: image tensor; the scaling assumes values in [0, 255].
        y: label tensor.

    Returns:
        Tuple ``(x, y)`` with ``x`` as float32 and ``y`` as int32.
    """
    # [0, 255] --> [-1, 1]
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.cast(y, dtype=tf.int32)

    return x, y


batch_size = 128
# x: [50k, 32, 32, 3], y: [50k, 1]; x_val: [10k, 32, 32, 3], y_val: [10k, 1]
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
y = tf.squeeze(y)  # [50k, 1] --> [50k]
y_val = tf.squeeze(y_val)  # [10k, 1] --> [10k]
y = tf.one_hot(y, depth=10)  # [50k, 10]
y_val = tf.one_hot(y_val, depth=10)  # [10k, 10]
print('datasets:', x.shape, y.shape, x_val.shape, y_val.shape, x.min(),
      x.max())

# Training pipeline: normalise, shuffle, then batch.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)
# Validation pipeline: same preprocessing, no shuffling.
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batch_size)

# Pull one batch to sanity-check the shapes.
sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)


class MyDense(layers.Layer):
    """Bias-free fully connected layer, used in place of layers.Dense().

    Args:
        inp_dim: size of the last axis of the input.
        outp_dim: number of output units.
    """

    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()

        # add_weight with keyword arguments replaces the deprecated
        # add_variable('w', shape) call.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        # No bias term: the layer computes inputs @ kernel only.

    def call(self, inputs, training=None):
        """Return ``inputs @ kernel``; ``training`` is accepted but unused."""
        x = inputs @ self.kernel
        return x


class MyNetwork(keras.Model):
    """Five-layer MLP built from MyDense layers; outputs 10 logits."""

    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Flatten the images and run them through the dense stack.

        Args:
            inputs: image batch, [b, 32, 32, 3].
            training: accepted but unused.

        Returns:
            Logits of shape [b, 10]; no softmax is applied here.
        """
        # [b, 32, 32, 3] --> [b, 32*32*3]
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # [b, 32*32*3] --> [b, 256]
        x = self.fc1(x)
        x = tf.nn.relu(x)
        # [b, 256] --> [b, 128]
        x = self.fc2(x)
        x = tf.nn.relu(x)
        # [b, 128] --> [b, 64]
        x = self.fc3(x)
        x = tf.nn.relu(x)
        # [b, 64] --> [b, 32]
        x = self.fc4(x)
        x = tf.nn.relu(x)
        # [b, 32] --> [b, 10]
        x = self.fc5(x)

        return x


# Single source of truth for the checkpoint location, so the save call, the
# load call and the log message cannot drift apart.
WEIGHTS_PATH = 'weights.ckpt'

# Train, evaluate and save the weights.
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)

network.evaluate(test_db)
network.save_weights(WEIGHTS_PATH)
del network
print('saved to', WEIGHTS_PATH)

# Build a fresh model and restore the saved weights into it.
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# NOTE(review): the weights learned by this fit are replaced by load_weights
# below; confirm whether this training pass is still needed.
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)
network.load_weights(WEIGHTS_PATH)
print('loaded weights from file.')

network.evaluate(test_db)
datasets: (50000, 32, 32, 3) (50000, 10) (10000, 32, 32, 3) (10000, 10) 0 255
batch: (128, 32, 32, 3) (128, 10)
Epoch 1/5
391/391 [==============================] - 7s 19ms/step - loss: 1.7276 - accuracy: 0.3358 - val_loss: 1.5801 - val_accuracy: 0.4427
Epoch 2/5
391/391 [==============================] - 7s 18ms/step - loss: 1.5045 - accuracy: 0.4606 - val_loss: 1.4808 - val_accuracy: 0.4812
Epoch 3/5
391/391 [==============================] - 6s 17ms/step - loss: 1.3919 - accuracy: 0.5019 - val_loss: 1.4596 - val_accuracy: 0.4921
Epoch 4/5
391/391 [==============================] - 7s 18ms/step - loss: 1.3039 - accuracy: 0.5364 - val_loss: 1.4651 - val_accuracy: 0.4950
Epoch 5/5
391/391 [==============================] - 6s 16ms/step - loss: 1.2270 - accuracy: 0.5622 - val_loss: 1.4483 - val_accuracy: 0.5030
79/79 [==============================] - 1s 11ms/step - loss: 1.4483 - accuracy: 0.5030
saved to ckpt/weights.ckpt
Epoch 1/5
391/391 [==============================] - 7s 19ms/step - loss: 1.7216 - val_loss: 1.5773
Epoch 2/5
391/391 [==============================] - 10s 26ms/step - loss: 1.5010 - val_loss: 1.5111
Epoch 3/5
391/391 [==============================] - 8s 21ms/step - loss: 1.3868 - val_loss: 1.4657
Epoch 4/5
391/391 [==============================] - 8s 20ms/step - loss: 1.3021 - val_loss: 1.4586
Epoch 5/5
391/391 [==============================] - 7s 17ms/step - loss: 1.2276 - val_loss: 1.4583
loaded weights from file.
79/79 [==============================] - 1s 12ms/step - loss: 1.4483





1.4482733222502697

查看更多关于CIFAR10自定义网络实战的详细内容...

  阅读:37次

上一篇: 模型加载与保存

下一篇:什么是卷积