TensorFlow 1.x CNN Tutorial
2. 모델 생성
3. 실행 하기 (그래프모드) session.run
import tensorflow as tf
input1 = tf.constant([3.0])
input2 = tf.constant([2.0])
input3 = tf.constant([5.0])
intermed12 = tf.add(input1, input2)
intermed23 = tf.add(input2, input3)
ops = {'a': intermed12,
'b': intermed12,
}
with tf.Session() as sess:
result1, resutl2 = sess.run([ops['a'],ops['b']] )
print(result1)
`
Dynamic input with feed_dic
input1 = tf.placeholder(tf.float32) # feed_dict={input1:[3.3]
input2 = tf.placeholder(tf.float32) # feed_dict={input2:[3.3]
input3 = tf.placeholder(tf.float32) # feed_dict={input3:[3.3]
intermed12 = tf.add(input1, input2)
intermed23 = tf.add(input2, input3)
ops = {'a': intermed12,
'b': intermed23,
}
with tf.Session() as sess:
result1, resutl2 = sess.run([ops['a'],ops['b']], feed_dict={input1:[3.3],input2:[2.2],input3:[5.5]} )
print(result1)
4. 저장 & 복원
sess = tf.Session()
saver = tf.train.Saver() # 위와 순서 바뀌면 `No variables to save`에러 출력
#saver.restore(sess, restore_model_path) #복원
sess.run(...)
save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) #저장
5. 테스트
Placeholder
일종의 자료형, 다른 텐서를 할당하는 것
placeholder의 전달 파라미터는 다음과 같다.
placeholder(
dtype, # 데이터 타입을 의미하며 반드시 적어주어야 한다.
shape=None, # 입력 데이터의 형태를 의미한다. 상수 값이 될 수도 있고 다차원 배열의 정보가 들어올 수도 있다. ( 디폴트 파라미터로 None 지정 )
name=None # 해당 placeholder의 이름을 부여하는 것으로 적지 않아도 된다. ( 디폴트 파라미터로 None 지정 )
)
tf.get_variable()과 tf.get_collection()
- tf.get_variable() : 텐서의 저장 공간의 주 형태인 variable을 선언하는 방법
tf.Variable()
가 원래 선언 방법. 하지만,tf.get_variable()
를 사용하는 것이 좀 더 범용적- 이유 :
get_variable()
은 정의된 name filed값과 동일한 텐서가 존재할 경우, 새로 만들지 않고 기존 텐서를 불러들인다.
def get_variable(name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None, # variable의 소속
caching_device=None,
partitioner=None,
validate_shape=True,
custom_getter=None):
# 출처: https://eyeofneedle.tistory.com/24 [Technology worth spreading
- tf.get_collection() : collection은 variable의 소속
- 목적은 해당 variable을 코드의 다른 위치에서 불러오기 위해서
- tf.get_collection(key)가 실행되면, key의 collection에 속하는 variable들의 리스트가 리턴
- tf.get_collection()사용법 7가지 [자세히]