[Tensorflow] tf.cond()로 조건에 따른 함수 반환
by Heejin Do
Tensorflow에서 tf.cond()
를 이용해 조건의 True/False 여부에 따라 다른 함수를 반환하는 방법을 소개한다.
tf.cond()
는 pred가 참일 때 treu_fn을 반환하고 거짓일 때 false_fn을 반환한다.
tensorflow.cond( pred, true_fn=None, false_fn=None, name=None )
예를 들어, 아래와 같이 a >= b라는 조건을 pred로 지정하고
참일 경우 tf.multiply(a, 2)를 반환하도록, 거짓일 경우 tf.multiply(b, 2)를 반환하도록 하면 result 값은 14를 가진다.
import tensorflow as tf
a = 3
b = 7
result = tf.cond(a >= b, lambda: tf.multiply(a, 2), lambda: tf.multiply(b, 2))
print(result)
# tf.Tensor(14, shape=(), dtype=int32)
아래처럼 함수를 미리 정의하고 사용할 수도 있다.
import tensorflow as tf
a = 3
b = 7
def f1(): return tf.multiply(a, 2)
def f2(): return tf.multiply(b, 2)
result = tf.cond(a >= b, f1, f2)
print(result)
# tf.Tensor(14, shape=(), dtype=int32)
결과값은 동일하다.
Subscribe via RSS