• 카테고리

    질문 & 답변
  • 세부 분야

    딥러닝 · 머신러닝

  • 해결 여부

    해결됨

cross_val_score에 대한 질문이 있습니다.

20.06.14 20:58 작성 조회수 477

0

dt_clf=DecisionTreeClassifier()

score=cross_val_score(dt_clf,X,y,cv=5)

로 돌리고

결과가 궁금한 데이터를 갖고 와서 돌리면

pred=dt_clf.predict(b)

라고 돌리면 

NotFittedError: This DecisionTreeClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

이런 에러가 호출되더라구요

cross val score가 핏 프레딕트 kfold가 한번에 되는거라고 하셨는데 

새로운 데이터를 학습시킨것에 적용해보려면 fit으로 한번 더 돌려야 하는건가요?

답변 4

·

답변을 작성해보세요.

0

gkgktmd님의 프로필

gkgktmd

질문자

2020.06.17

감사합니다!

0

kfold로 하셔도 되고, GridSearchCV 를 이용하실 수도 있습니다.

GridSearchCV로 최적화 한 후에 best_estimator_ 속성으로 해당 Estimator를 반환받을 수 있습니다.

감사합니다.

0

gkgktmd님의 프로필

gkgktmd

질문자

2020.06.16

아 그렇다면 cross는 사실상 정확도를 빠르게 확인하기 위함이고

학습된 정보를 기반으로 실데이테를 예측해보는건 kfold로 해야된다는 말씀이실까요?

(그냥 스플릿 트레인 테스트말고 5번 나눠서 하는 형식으로 하고 싶은 경우)

0

안녕하십니까,

빨리 답변을 달아드린다는게 깜박 잊고 답변이 늦었습니다.

cross_val_score(dt_clf,X,y,cv=5) API로 인자로 들어가는 dt_clf가 call by reference 일 텐데, 내부 함수에서 dt_clf를 복사하는 것 같습니다. 때문에 API에서 복사된 dt_clf 를 5번 학습을 할텐데, 실제로 dt_clf는 학습된 객체가 아닙니다. 이게 API 설계를 그런식으로 한것 같습니다.

요약하자면 cross_val_score( ) 인자로 dt_clf 를 넣더라도 dt_clf 자체가 학습이 되지는 않고 dt_clf를 복제한 object 를 학습/검증 후 평균 메트릭을 반환합니다.

감사합니다.