Deep Learning Frameworks

다양한 딥러닝 프레임워크 중에서 선택하는 팁

- 프로그래밍하기 편한 것

- 빠른 것

- 오픈소스인 것

 

 

TensorFlow

아주 간단한 cost function을 상정해보자

 

$J(w)  = w^{2} - 10w + 25$

 

import!!

import numpy as np
import tensorflow as tf

 

define the parameter w

w = tf.Variable(0, dtype=tf.float32) # variable to optimize
optimizer = tf.keras.optimizers.Adam(0.1) # learning rate 0.1

def train_step():
    with tf.GradientTape() as tape:
        cost = w ** 2 - 10 * w + 25
        # only have to write foward prop then tf can do the back prop
        # GradientTape : record the seqeunce of operations in forward prop, 
        # then tape backwards for back prop
    trainable_varibales = [w]
    grads = tape.gradient(cost, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))
    
print(w)

 

one step of learning

train_step()
print(w)

살짝 상승

 

for i in range(1000):
	train_step()
print(w)

$w$가 $5$에 가까워짐

 

----

위에서는 변수가 $w$ 뿐이었지만, 파라미터 $w$ 뿐만 아니라 나의 데이터에 따라 바뀌는 cost function 일 때는 어떻게 할까?

 

training data $x$가 있다고 할 때,

w = tf.Variable(0, dtype = tf.float32)
x = np.array([1.0, -10.0, 25.0], dtype=np.float32)
optimizer = tf.keras.optimizers.Adam(0.1)

def training(x, w, optimizer):
    def cost_fn():
        return cost = x[0] * w ** 2 + x[1] * w + x[2]
        # array x controls the coefficients of the cost function
    for i in range(1000):
    	optimizer.minimize(cost_fn, [w])
    return w
    
w = training(x, w, optimizer)
    
optimizer.minimize(cost_fn, [w])
# 한번 학습
# optimizer.minimize가 아래 코드와 동일
# trainable_varibales = [w]
# grads = tape.gradient(cost, trainable_variables)
# optimizer.apply_gradients(zip(grads, trainable_variables))

학습 완료

 

def cost_fn():
	return x[0] * w ** 2 + x[1] * w + x[2]

텐서플로우로 하여금 computation graph 를 구성하게 함

 

'인공지능 > DLS' 카테고리의 다른 글

[3.1.] Setting Up your Goal  (0) 2022.07.17
[3.1.] Introduction to ML Strategy  (0) 2022.07.17
[2.3.] Multi-class Classification  (0) 2022.07.13
[2.3.] Batch Normalization  (0) 2022.07.12
[2.3.] Hyperparameter Tuning  (0) 2022.07.12
복사했습니다!