multiple inputs
327
작성한 질문수 20
예시에서는 독립변수가 하나밖에 존재하지 않은데 2개 이상일때는 call 메서드에 어떻게 적용시키나요? 참고로 from_tensor_slices로 dataset을 만들었을 때 궁금합니다.
아래의 코드는 제가 만든 모델 코드의 일부분 입니다.
다음과 같이 from_tensor_slices로 총 7개의 변수를 가진 tf.dataset을 생성했습니다.
이후 main model의 call 메서드에 인자를 다음과 같이 6개를 지정해서 model을 학습시켰습니다. 이렇게 하는게 맞는지 궁금합니다.
train_ds = tf.data.Dataset.from_tensor_slices((train_user, train_item, train_user_bert, train_item_bert, train_user_roberta, train_item_roberta, train_rating))
def call(self, user, item, user_text_bert,item_text_bert, user_text_roberta, item_text_roberta):
user_vec = self.user_emb(user)
item_vec = self.item_emb(item)
for epoch in range(Epochs):
for a, b, c, d, e, f, g in train_ds:
with tf.GradientTape() as tape:
predictions = model(a, b, c, d, e, f)
loss = loss_object(g, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
답변 1
0
안녕하세요, 인프런 AI 인턴이에요.
여러 개의 독립 변수를 가진 경우, 각 변수마다 모델에 필요한 인자를 전달해야 합니다. 예를 들어, 7개의 변수를 가진 경우, call 메서드에서 각 변수에 대한 인자를 따로 전달해야 합니다.
아래는 예시 코드입니다.
def call(self, user, item, user_text_bert, item_text_bert, user_text_roberta, item_text_roberta):
user_vec = self.user_emb(user)
item_vec = self.item_emb(item)
# 나머지 변수에 대해서도 필요한 전처리 및 연산 수행
# 예시로 다음과 같이 출력해 볼 수 있습니다.
print(user_vec, item_vec, user_text_bert, item_text_bert, user_text_roberta, item_text_roberta)
# 모델의 나머지 부분 구현 코드
위의 코드에서는 user, item 변수는 따로 처리되고, 나머지 변수들도 필요한 작업을 수행한 후에 모델에 전달됩니다.
그리고 학습하는 코드 부분에서도 인자를 잘 전달하셨습니다.
for a, b, c, d, e, f, g in train_ds:
with tf.GradientTape() as tape:
predictions = model(a, b, c, d, e, f)
loss = loss_object(g, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
맞게 구현하셨습니다.
각 변수의 전처리 및 모델에 필요한 연산을 잘 적용하시고 사용하시면 됩니다. 추가적인 질문이 있으시면 언제든지 물어보세요. 즐거운 학습 되세요!
[긴급요청2] 28강 동영상 오류(27강 음성 재생) 수정 예정일자를 알려주세요
0
10
1
[긴급요청] 28강 동영상 전체에 오디오가 잘못들어가있습니다.
0
25
1
17강 강의 마지막 부분에서 헷갈리는 부분이 있습니다.
0
24
0
14강의 예시 데이터를 통한 답을 구하는 과정에서의 문의
0
26
1
강의 자료 열람 요청이 안되네요
0
18
1
예측에 사용하는 경제지표의 활용에 대해
1
32
2
강의 자료를 어디서 확인 할 수 있나요?
1
44
3
실습 코드
0
35
2
MMDetection 버전 이슈
0
59
2
[업로드오류] 강의 내용과 제목 불일치
0
46
2
섹션22 퀴즈 질문
0
59
1
import torch가 안되는 경우는 어떻게 하나요?
0
57
1
소리가 겹쳐서 들려요
0
66
2
20강에서 파인튜닝 때 사용한 데이터가 없어졌습니다. LoRA Trainer 매개변수도 라이브러리 업그레이드로 수정되었습니다.
0
46
1
[개정판] 딥러닝 컴퓨터 비전 완벽 가이드 먼저? 구현하며 배우는 Transformer 먼저?
0
64
1
수업자료
0
43
2
동영상 재생오류
0
62
1
epoch 1부터 loss가 너무 낮게 나와 학습이 안되네요
0
552
1
텍스트 데이터일 때의 dtype
0
275
1
12강 data split take와 skip
0
326
1
<tensorflow사용메뉴얼> 강의파일
0
296
0
SyntaxError: keyword can't be an expression
0
522
0
강의 감사합니다. 다섯번째 강의인 Model Implementation가 재생이 안됩니다.
0
176
0
unsupported operand type(s) for *: 'float' and 'NoneType'
0
1974
1





