Both the generator and the discriminator are simple fully connected (dense) neural networks.

import os

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Dropout, Input, LeakyReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt


# Load the dataset
def load_data():
    (x_train, y_train), (_, _) = mnist.load_data()
    # Scale pixel values to [-1, 1] to match the generator's tanh output
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5

    # Convert shape from (60000, 28, 28) to (60000, 784)
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train)


X_train, y_train = load_data()
print(X_train.shape, y_train.shape)
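For reference, with the standard MNIST training split this prints (60000, 784) (60000,): 60,000 images flattened to 784-dimensional vectors, plus their labels.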


def build_generator():
    model = Sequential()

    model.add(Dense(units=256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(units=512))
    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(units=1024))
    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(units=784, activation='tanh'))

    # Compiling the standalone generator is optional; it is only ever trained
    # through the combined GAN model below
    model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return model


generator = build_generator()
generator.summary()
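As a quick sanity check (not part of the original script), you can feed the generator a single noise vector and confirm that it outputs one flattened 28x28 image:

# Illustrative check: a 100-dimensional noise vector in, a 784-dimensional image out
sample_noise = np.random.normal(0, 1, (1, 100))
print(generator.predict(sample_noise).shape)  # (1, 784)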


def build_discriminator():
    model = Sequential()

    model.add(Dense(units=1024, input_dim=784))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Dense(units=512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Dense(units=256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Dense(units=1, activation='sigmoid'))

    model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return model


discriminator = build_discriminator()
discriminator.summary()


def build_GAN(discriminator, generator):
    # Freeze the discriminator so that only the generator's weights are updated
    # when the combined model is trained
    discriminator.trainable = False
    GAN_input = Input(shape=(100,))
    x = generator(GAN_input)
    GAN_output = discriminator(x)
    GAN = Model(inputs=GAN_input, outputs=GAN_output)
    GAN.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return GAN


GAN = build_GAN(discriminator, generator)
GAN.summary()


def draw_images(generator, epoch, examples=25, dim=(5, 5), figsize=(10, 10)):
    noise = np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys')
        plt.axis('off')
    plt.tight_layout()
    os.makedirs('img', exist_ok=True)  # make sure the output directory exists
    plt.savefig('img/Generated_images %d.png' % epoch)
    plt.close()  # free the figure so memory does not accumulate over many epochs


def train_GAN(epochs=1, batch_size=128):
    # Loading the data
    X_train, y_train = load_data()

    # Creating GAN
    generator = build_generator()
    discriminator = build_discriminator()
    GAN = build_GAN(discriminator, generator)

    for i in range(1, epochs + 1):
        print("Epoch %d" % i)

        # Note: each "epoch" here performs batch_size mini-batch updates,
        # each on a freshly sampled random batch of real images
        for _ in tqdm(range(batch_size)):
            # Generate fake images from random noise
            noise = np.random.normal(0, 1, (batch_size, 100))
            fake_images = generator.predict(noise)

            # Select a random batch of real images from MNIST
            real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]

            # Labels for fake and real images
            label_fake = np.zeros(batch_size)
            label_real = np.ones(batch_size)

            # Concatenate fake and real images
            X = np.concatenate([fake_images, real_images])
            y = np.concatenate([label_fake, label_real])

            # Train the discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y)

            # Train the generator/chained GAN model (with frozen weights in discriminator)
            discriminator.trainable = False
            GAN.train_on_batch(noise, label_real)

        # Draw generated images after the first epoch and then every 10 epochs
        if i == 1 or i % 10 == 0:
            draw_images(generator, i)


train_GAN(epochs=500, batch_size=128)
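
For context, the training loop above is a stochastic approximation of the standard GAN minimax game from Goodfellow et al. (2014):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The discriminator step pushes $D$ toward outputting 1 for real images and 0 for generated ones, while the generator step trains $G$ through the frozen discriminator against the "real" label, which corresponds to the usual non-saturating generator loss.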

After 1 epoch:
Generated_images 1.png

After 100 epochs:
Generated_images 100.png

After 200 epochs:
Generated_images 200.png

After 300 epochs:
Generated_images 300.png

After 400 epochs:
Generated_images 400.png

After 500 epochs:
Generated_images 500.png

Only simple fully connected networks are used here; other network architectures, for example convolutional ones, should give better results (a rough sketch follows).
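
As a rough sketch (my own illustration, not code from this post), a DCGAN-style convolutional generator in the same tf.keras style could look like the following; the discriminator would analogously be built from Conv2D layers, and the real MNIST images would then be kept as 28x28x1 arrays instead of flattened 784-dimensional vectors:

from tensorflow.keras.layers import Reshape, Conv2DTranspose, BatchNormalization

def build_conv_generator():
    # Hypothetical DCGAN-style generator: project the noise onto a 7x7 feature map,
    # then upsample twice with transposed convolutions to reach 28x28x1
    model = Sequential()
    model.add(Dense(7 * 7 * 128, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 128)))
    model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'))   # -> 14x14x64
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2DTranspose(1, kernel_size=4, strides=2, padding='same', activation='tanh'))  # -> 28x28x1
    return model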

Last modified: October 7, 2021