참고 : https://www.youtube.com/playlist?list=PLlMkM4tgfjnLSOjrEJN31gZATbcj_MpUm
모두를 위한 딥러닝 강좌 시즌 1
www.youtube.com
이번에는 저번에 배운걸 tensorflow로 구현해보는 시간
원래는 Wx 뒤에 +b 가 있을 수 있지만 여기선 생략하고 생각해보자
범위는 -30에서 50으로 변경하면서 결과를 본다
plt 라는 거는 import 하면 그래프를 그릴 수 있다.
저런식으로! 그래프 모양이 나타나겠지.
cost를 최소화 하는 W는 1이겠다.
그래프를 더 자세히 살펴보자.
저번 시간에 봤던 것 처럼
기울기가 양수 일때에는 w를 -방향으로 이동하면서 cost 값을 살펴보고
기울기가 음수 일때에는 w를 +방향으로 이동하면서 cost 값을 살펴본다.
->이제 이거를 tensorflow로 구현하려고 한다.
그냥 수식 그대로 코드를 짜주고
마지막에 assign을 해야하는데 tensorflow의 경우는 다른 언어들처럼 = 로 그냥 대입하는 형식으로 업데이트를 하지 못한다.
그래서 저런 .assign() 함수를 통해 업데이트를 해주기로 한다.
update를 바로 실행시키면 일련의 동작들이 일어나겠지!
전체 코드를 보자.
마지막에 update를 실행시킨다(x와 y의 데이터를 던져주면서!)
실행을 시키면
이런식으로 나오는데 cost는 점점 작아지고
W는 처음엔 랜덤인 수지만 1에 가까운 수가 된다.
잘 돌아가는구만
그런데 매번 이렇게 미분값을 할 수는 있지만
이렇게 하지말고
optimizer를 선언해서 cost 를 굳이 미분하지 않아서 tensorflow가 알아서 해준다.
둘이 같은 의미이다!
그래서 optimizer 를 적용해서 다시 코드를 써보면
처음에 W에 말이 안되는 값을 넣어보자. 5.0이라는 값!
이제 5에서 잘 내려가는지를 봐야겠지.
실행을 시키면 5.0이지만, 학습을 할수록 빠르게 내려가서 1.0이 된다.
저 optimizer 한 줄을 썼다고 몇줄만에 1.0을 찾아낸다.
이번에는 5.0말고 -3.0을 줬다.
이번엔 시작점이 왼쪽에 있겠지. 기울기가 음수인 곳에,
이 경우에도 마찬가지고 1.0에 가까워진다.
optimizer가 잘 동작하는구나~
optional!
tensorflow가 주는 gradient 값을 좀 조정하고 싶을 때가 있을지도 모른다.
optimizer 부분까지는 이전과 동일하다.
그런데 optimizer에서 minimize하라고 바로 하는게 아니라
optimizer에서 gradient를 계산해 달라고 하는거다. 이 계산한 값은 gvs 에 넣어보자.
이 값을 원하는대로 수정할 수 있는거다.
수정이 끝나면 다시 apply_gradient를 해서 optimizer에 다시 apply를 하는 거다.
추가로, 자동으로 계산된 gradient랑 우리가 수식적으로 계산한 gradient(미분한거) 게 같은값인지 확인해보자.
돌려서 gradient(수식으로 세워서 구한거), W, gvs(자동으로 계산된 거)를 출력하게 하면
보면 동그라미 2번이랑 동그라미 4번의 gradient 값이 동일한 걸 확인할 수 있다~
import tensorflow.compat.v1 as tf
X = [1,2,3]
Y = [1,2,3]
W = tf.Variable(5.0)
hypothesis = X*W
cost = tf.reduce_mean(tf.square(hypothesis-Y))
#minimize
optimizer = tf.train.GradientDescentOptimizer(learning_rate = 0.1)
train=optimizer.minimize(cost)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range (100):
print(step, sess.run(W))
sess.run(train)
이걸 돌려보는데보는데
RuntimeError: `loss` passed to Optimizer.compute_gradients should be a function when eager execution is enabled.
오류가 뜬다
음 그렇군
tf.disable_v2_behavior()
이거 추가하니까 잘 돌아간다.
오류해결!