Inflearn brand logo image

인프런 커뮤니티 질문&답변

sin님의 프로필 이미지
sin

작성한 질문수

유니티 머신러닝 에이전트 완전정복 (기초편)

A2C 실습 2

dqn 인덱스 에러

작성

·

33

0

dqn 학습 부분에서, index에러가 자꾸 납니다..

 

dqn.py는 github에 올려져 있는 것을 그대로 사용했습니다.

 

< 에러 코드 >

(colab) C:\Users\pss60\Desktop\ML-Agents_Project\agents>python dqn.py
Traceback (most recent call last):
  File "C:\Users\pss60\Desktop\ML-Agents_Project\agents\dqn.py", line 218, in <module>
    state = preprocess(dec.obs[OBS], dec.obs[GOAL_OBS])
IndexError: list index out of range

 

< GridAgent >

using System;
using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Rendering;
using UnityEngine.Serialization;
using System.Collections.Generic;

public class GridAgent : Agent
{
    [FormerlySerializedAs("m_Area")]
    [Header("Specific to GridWorld")]
    public GridArea area;
    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " +
        "a camera to render before making a decision. Place the agentCam here if using " +
        "RenderTexture as observations.")]
    public Camera renderCamera;

    VectorSensorComponent m_GoalSensor;

    private Transform agentTrans = null;

    private Vector3 moveRight = new Vector3(1, 0, 0);
    private Vector3 moveLeft = new Vector3(-1, 0, 0);
    private Vector3 moveUp = new Vector3(0, 0, 1);
    private Vector3 moveDown = new Vector3(0, 0, -1);

    public enum GridGoal
    {
        GreenPlus,
        RedEx,
    }

    // Visual representations of the agent. Both are blue on top, but different colors on the bottom - this
    // allows the user to see which corresponds to the current goal, but it's not visible to the camera.
    // Only one is active at a time.
    public GameObject GreenBottom;
    public GameObject RedBottom;

    GridGoal m_CurrentGoal;

    public GridGoal CurrentGoal
    {
        get { return m_CurrentGoal; }
        set
        {
            switch (value)
            {
                case GridGoal.GreenPlus:
                    GreenBottom.SetActive(true);
                    RedBottom.SetActive(false);
                    break;
                case GridGoal.RedEx:
                    GreenBottom.SetActive(false);
                    RedBottom.SetActive(true);
                    break;
            }
            m_CurrentGoal = value;
        }
    }

    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
        "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    const int k_NoAction = 0;  // do nothing!
    const int k_Up = 1;
    const int k_Down = 2;
    const int k_Left = 3;
    const int k_Right = 4;

    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        m_GoalSensor = this.GetComponent<VectorSensorComponent>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        Array values = Enum.GetValues(typeof(GridGoal));
        int goalNum = (int)CurrentGoal;
        

        // 현재 에이전트의 x, y 좌표 위치 값을 관측 정보에 추가
        sensor.AddObservation(agentTrans.position.x);
        sensor.AddObservation(agentTrans.position.z);

        // 각각 도형에 대한 좌표 위치 값을 관측 정보에 추가
        List<int> otherPos = area.otherPos;
        for (int i = 0; i < otherPos.Count; i++)
            sensor.AddObservation(otherPos[i]);

        // 목표 지점에 대한 정보
        m_GoalSensor.GetSensor().AddOneHotObservation(goalNum, values.Length);
    }

    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // Mask the necessary actions if selected by the user.
        if (maskActions)
        {
            // Prevents the agent from picking an action that would make it collide with a wall
            var positionX = (int)agentTrans.localPosition.x;
            var positionZ = (int)agentTrans.localPosition.z;
            var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;

            if (positionX == 0)
            {
                actionMask.SetActionEnabled(0, k_Left, false);
            }

            if (positionX == maxPosition)
            {
                actionMask.SetActionEnabled(0, k_Right, false);
            }

            if (positionZ == 0)
            {
                actionMask.SetActionEnabled(0, k_Down, false);
            }

            if (positionZ == maxPosition)
            {
                actionMask.SetActionEnabled(0, k_Up, false);
            }
        }
    }

    // to be implemented by the developer
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // 매 행동마다 -0.01 보상(패널티) 부여
        AddReward(-0.01f);
        var action = actionBuffers.DiscreteActions[0];

        // 에이전트가 이동하게 될 위치 값을 저장할 변수 (Vector3)
        var targetPos = agentTrans.position;

        switch (action)
        {
            case k_NoAction:
                // do nothing
                break;
            case k_Right: // 부여받은 행동이 오른쪽이라면,
                targetPos = agentTrans.position + moveRight; // 현재 위치에서 1만큼 x축 방향으로 설정
                break;
            case k_Left:
                targetPos = agentTrans.position + moveLeft;
                break;
            case k_Up:
                targetPos = agentTrans.position + moveUp;
                break;
            case k_Down:
                targetPos = agentTrans.position + moveDown;
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        var hit = Physics.OverlapBox(
            targetPos, new Vector3(0.3f, 0.3f, 0.3f));

        // 벽에 부딪히지 않았다면,
        if (hit.Where(col => col.gameObject.CompareTag("wall")).ToArray().Length == 0)
        {
            // 정해진 위치로 이동
            agentTrans.position = targetPos;

            // +오브젝트 만났다면, 
            if (hit.Where(col => col.gameObject.CompareTag("plus")).ToArray().Length == 1)
            {
                ProvideReward(GridGoal.GreenPlus);
                EndEpisode();
            }
            // x오브젝트 만났다면,
            else if (hit.Where(col => col.gameObject.CompareTag("ex")).ToArray().Length == 1)
            {
                ProvideReward(GridGoal.RedEx);
                EndEpisode();
            }
        }
    }

    private void ProvideReward(GridGoal hitObject)
    {
        if (CurrentGoal == hitObject)
        {
            SetReward(1f);
        }
        else
        {
            SetReward(-1f);
        }
    }

    // WASD 키보드로 에이전트 이동
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = k_NoAction;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = k_Right;
        }
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = k_Up;
        }
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = k_Left;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = k_Down;
        }
    }

    // to be implemented by the developer
    public override void OnEpisodeBegin()
    {
        area.AreaReset();
        Array values = Enum.GetValues(typeof(GridGoal));
        CurrentGoal = (GridGoal)values.GetValue(UnityEngine.Random.Range(0, values.Length));
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    void WaitTimeInference()
    {
        if (renderCamera != null && SystemInfo.graphicsDeviceType != GraphicsDeviceType.Null)
        {
            renderCamera.Render();
        }

        if (Academy.Instance.IsCommunicatorOn)
        {
            RequestDecision();
        }
        else
        {
            if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                m_TimeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                m_TimeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}

답변 3

0

민규식님의 프로필 이미지
민규식
지식공유자

그리고 강의는 A2C 실습 부분인듯 한데 혹시 OBS가 Vector obs 일까요?

0

민규식님의 프로필 이미지
민규식
지식공유자

안녕하세요!

여기서 인덱스 에러가 발생하는 것 보면 골 정보가 obs에 포함되어 있지 않거나 뭔가 문제가 있어보이는데요...!

혹시 유니티나 ML-Agents는 어떤 버전을 사용하고 계실까요? 버전에 따라 이런 에러가 발생할 수 있고 특히 최신 ML-Agents 버전에서 시각적 관측 사용할 때 에러가 잘 난다고 알고있어서 확인해보면 좋을 것 같습니다!

0

안녕하세요, 인프런 AI 인턴입니다. sin님, DQN 학습 과정 중 발생한 인덱스 오류와 관련하여 도움이 될 수 있는 자료를 안내해 드립니다.

현재 설명하신 문제는 dec.obs[OBS]dec.obs[GOAL_OBS]에서 인덱스 범위를 넘어서는 접근을 시도하면서 발생하는 것으로 보입니다. 이는 주로 관측 데이터가 예상보다 적을 때 발생할 수 있으며, 이와 관련된 기존의 사례들에서 유사한 해결책을 찾을 수 있습니다.

유사한 문제를 다룬 기존의 게시물 중 하나인 DQN 환경 학습에서는 입력 데이터의 차원 불일치 문제에 대해 논의하고 있습니다. 주로 입력 데이터의 구조, 특히 관측치의 개수와 관련해 확인할 필요가 있습니다. 이 문제의 해결책은 주로 유니티 환경 내 설정 확인과, 전달된 관측 데이터의 크기 조정에 달려 있습니다.

  • 체크포인트:

    • Unity 환경 내의 시각적 관측(sensor) 설정을 확인하거나, 모든 관측 데이터가 Python 측에 올바르게 전달되고 있는지 확인하세요.
  • 디버깅 팁:

    • 코드 중 dec.obs의 길이를 print로 출력하여 기대하는 길이와 실제로 넘어오는 길이를 비교해 보세요.
    • 필요시, 인덱스에 따른 로직을 수정하여 문제가 발생하지 않도록 조치할 수 있습니다.

문제가 지속될 경우, 강의자분이나 유니티 포럼을 통하여 직접적인 지침을 받는 것도 좋습니다.

저는 질문자님의 이해를 돕고자 지식공유자님의 기존 답변을 참고하여 유사한 답변 링크를 추천해드리고 있습니다.
현재 베타 기능으로 답변이 만족스럽지 않을 수 있는 점 양해 부탁드립니다. 🙏
추가적으로 궁금한 점이 있으시면, 이어서 질문해 주세요. 곧 지식공유자께서 답변해 주실 것입니다.

sin님의 프로필 이미지
sin

작성한 질문수

질문하기