Gobin's material
import os

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)

# create the sample output directory up front so plt.savefig() below does not fail
os.makedirs('samples2', exist_ok=True)
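
# hyperparameters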
total_epoch = 1000
batch_size = 100
n_input = 28 * 28
n_noise = 128
n_class = 10
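
# placeholders: X = real images (flattened 28x28), Y = one-hot digit labels (the condition), Z = latent noise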
X = tf.placeholder(tf.float32, [None, n_input])
Y = tf.placeholder(tf.float32, [None, n_class])
Z = tf.placeholder(tf.float32, [None, n_noise])
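
# the generator maps (noise, label) -> a 784-pixel image;
# concatenating the one-hot label with the noise is what makes this a conditional GAN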
def generator(noise, labels):
    with tf.variable_scope('generator'):
        inputs = tf.concat([noise, labels], axis=1)
        G1 = tf.contrib.layers.fully_connected(inputs, 256)
        G2 = tf.contrib.layers.fully_connected(G1, n_input)
    return G2
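
# the discriminator maps (image, label) -> a single real/fake logit;
# reuse=True shares one set of weights between the real-image and generated-image passes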
def discriminator(inputs, labels, reuse=None):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()
        inputs = tf.concat([inputs, labels], axis=1)
        D1 = tf.contrib.layers.fully_connected(inputs, 256)
        D2 = tf.contrib.layers.fully_connected(D1, 256)
        D3 = tf.contrib.layers.fully_connected(D2, 1, activation_fn=None)
    return D3
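
# build the graph: G produces fake images from noise, and D scores real and generated images under the same labels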
G = generator(Z, Y)
D_real = discriminator(X, Y)
D_gene = discriminator(G, Y, True)
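
# standard GAN losses on the sigmoid logits: D is trained toward 1 on real images and 0 on fakes,
# while G is trained to make D output 1 on its samples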
loss_D_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real, labels=tf.ones_like(D_real)))
loss_D_gene = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_gene, labels=tf.zeros_like(D_gene)))
loss_D = loss_D_real + loss_D_gene
loss_G = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_gene, labels=tf.ones_like(D_gene)))
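
# collect each network's variables by scope so each optimizer only updates its own sub-graph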
vars_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
vars_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
train_D = tf.train.AdamOptimizer().minimize(loss_D, var_list=vars_D)
train_G = tf.train.AdamOptimizer().minimize(loss_G, var_list=vars_G)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
total_batch = int(mnist.train.num_examples/batch_size)
loss_val_D, loss_val_G = 0, 0
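
# training loop: for every mini-batch, take one discriminator step, then one generator step reusing the same noise batch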
for epoch in range(total_epoch):
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        noise = np.random.uniform(-1., 1., size=[batch_size, n_noise])
        _, loss_val_D = sess.run([train_D, loss_D],
                                 feed_dict={X: batch_xs, Y: batch_ys, Z: noise})
        _, loss_val_G = sess.run([train_G, loss_G],
                                 feed_dict={Y: batch_ys, Z: noise})
    print('Epoch:', '%04d' % (epoch + 1),
          'D loss: {:.4}'.format(loss_val_D),
          'G loss: {:.4}'.format(loss_val_G))
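    # every 10 epochs, save a grid that interleaves real validation digits (even rows) with generated digits (odd rows)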
    if epoch % 10 == 0:
        noise = np.random.uniform(-1., 1., size=[30, n_noise])
        samples = sess.run(G, feed_dict={Y: mnist.validation.labels[:30], Z: noise})

        fig, ax = plt.subplots(6, n_class, figsize=(n_class, 6))
        for i in range(n_class):
            for j in range(6):
                ax[j][i].set_axis_off()
            for j in range(3):
                ax[0 + (j * 2)][i].imshow(np.reshape(mnist.validation.images[i + (j * n_class)], (28, 28)))
                ax[1 + (j * 2)][i].imshow(np.reshape(samples[i + (j * n_class)], (28, 28)))

        plt.savefig('samples2/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
        plt.close(fig)
print('Optimization complete!')