function out = gw( arg, arg2)
% gw: a grid world
%
%   gw(CMD) or gw(CMD, ARG2) dispatches on CMD. All simulation state lives
%   in the global variables declared below, so successive calls share it.
%
%   Commands:
%     'new'    run 'world', 'agent' and 'init', in that order
%     'world'  create the 7x7 world: moves, movement loss, reward field,
%              start and goal
%     'agent'  reset the value table Val and the learning parameters
%     'init'   allocate the per-trial buffers (tmax = 100), clear the trace
%     'run'    run one trial from start; values are learned from TD errors
%              with an eligibility trace, actions are drawn by softmax
%     'try'    'init' followed by 'run', then refresh the gwf displays
%     'stop'   set the global stop flag that 'run' and the repeat form check
%     'movie'  replay the stored trajectory, record it and write it to
%              ARG2 (default 'gw.mov')
%     'save'   save the current workspace variables to ARG2
%              (default 'gw.mat')
%     'load'   list, then load, the variables stored in ARG2
%              (default 'gw.mat')
%     N, CMD   N a positive scalar or a numeric string: call gw(CMD) up to
%              N times, stopping early once the stop flag is set
%
%   out is [] whenever an output is requested; it exists only so that
%   calling gw with an output argument does not raise an error.
%
%   NOTE(review): gwf is defined outside this file; it is presumably the
%   drawing front end - confirm against its source.

global nx ny ns na move loss
global start goal Rew Val Elig
global alpha beta gamma delta lambda
global state action value reward t tmax stop
global ff Mov

switch( arg)
%
% initialization
%
case 'new'      % Setup everything
    gw( 'world');
    gw( 'agent');
    gw( 'init');

case 'world'    % A new world
    % states
    nx = 7;
    ny = 7;
    ns = nx*ny;
    % actions
    na = 5;     % R,U,L,D,stay
    move = [1,0; 0,1; -1,0; 0,-1; 0,0];
    loss = -0.1*[ 1; 1; 1; 1; 0];   % every move except 'stay' costs 0.1
    % reward field
    Rew = zeros( nx, ny);
    Rew(nx,ny) = 1;
    Rew(1,ny) = -1;
    Rew((1+nx)/2,(1+ny)/2) = -1;    % centre cell (nx and ny are odd)
    Rew(nx,1) = -1;
    % start and goal
    start = [1,1];
    goal = [nx,ny];
    gwf( 'init');
    gwf( 'reward');

case 'agent'    % A new agent
    Val = zeros(nx,ny);
    alpha = 0.5;    % learning rate
    beta = 2;       % softmax action selection
    gamma = 0.9;    % discount factor
    lambda = 0.5;   % eligibility trace
    gwf( 'value', Val);

case 'init'     % A new trial
    tmax = 100;
    state = zeros(tmax,2);
    action = zeros(tmax,1);
    reward = zeros(tmax,1);
    value = zeros(tmax,1);
    delta = zeros(tmax,1);
    Elig = zeros(nx,ny);
    gwf( 'agent', [start,Rew(start(1),start(2))]);
    drawnow;

%
% run
%
case 'run'
    stop = 0;
    st = start;
    for t=1:tmax
        if stop==1, break; end
        % new state
        state(t,:) = st;
        gwf( 'agent', [st,Rew(st(1),st(2))]);
        figure(ff);
        drawnow;
        % value
        value(t) = Val(st(1),st(2));
        if t > 1
            % TD error
            delta(t-1) = reward(t-1) + gamma*value(t) - value(t-1);
            % update value
            Val = Val + alpha*delta(t-1)*Elig;
        end
        % update eligibility trace: decay all cells, reset the current one
        Elig = gamma*lambda*Elig;
        Elig(st(1),st(2)) = 1;
        % Elig(st(1),st(2)) = Elig(st(1),st(2)) + (1-gamma*lambda);
        % final step: at the goal there is no successor value to bootstrap
        if isequal( state(t,:), goal)
            reward(t) = Rew(st(1),st(2));
            delta(t) = reward(t) - value(t);
            Val = Val + alpha*delta(t)*Elig;
            break;
        end
        % predict next states: each row for an action
        pstate = repmat( st, na, 1) + move;
        pstate = min( max( pstate,1), repmat([nx,ny],na,1));   % stay on grid
        % linear index
        istate = sub2ind( [nx,ny], pstate(:,1), pstate(:,2));
        % take an action by softmax
        pq = loss + gamma*Val(istate);  % each row for an action
        prob = exp( beta*pq);
        prob = prob./(sum(prob));       % selection probability
        act = find( cumsum(prob) > rand(1));
        if isempty( act)
            % rounding can leave cumsum(prob)(end) just below the draw;
            % the draw then belongs to the last action
            action(t) = na;
        else
            action(t) = act(1);         % index of selected action
        end
        % reward: from state and action
        reward(t) = Rew(st(1),st(2)) + loss(action(t));
        % disp([t,state(t,:),action(t),reward(t),value(t),delta(t)]);
        % next state
        st = st + move(action(t),:);
        st = min( max( st, 1), [nx,ny]);
    end
    disp( [t, mean( reward(1:t))]);     % mean reward

case 'try'
    gw( 'init');
    gw( 'run');
    gwf( 'wave', t);
    gwf( 'value');
    % gwf( 'traj', state(1:t,:));

case 'stop'
    stop = 1;

%
% movie making
%
case 'movie'
    % NOTE(review): moviein and qtwrite are legacy functions - confirm
    % that they exist in the MATLAB release this is run on.
    gwf( 'init', [320 240]);
    gwf( 'reward');
    gwf( 'agent', [state(1,:),Rew(state(1,1),state(1,2))]);
    tf = t;     % number of steps recorded by the last run
    figure(1);
    Mov = moviein( tf);
    % set(gca,'nextplot','replacechildren');
    disp( ' recording a movie...');
    for t = 1:tf
        gwf( 'agent', [state(t,:),Rew(state(t,1),state(t,2))]);
        drawnow;
        Mov(:,t) = getframe;
    end
    disp( ' replaying...');
    movie( Mov);
    if nargin < 2, arg2 = 'gw.mov'; end
    disp( [' saving to ' arg2]);
    qtwrite( Mov, get(1,'ColorMap'), arg2);

%
% Miscellaneous
%
case 'save'
    if nargin < 2, arg2 = 'gw.mat'; end
    save( arg2);

case 'load'
    if nargin < 2, arg2 = 'gw.mat'; end     % same default as 'save'
    whos( '-file', arg2);
    load( arg2);

otherwise
    % repeat form: gw( n, cmd) or gw( 'n', cmd)
    if isnumeric(arg)
        n = arg;                % repeat arg times
    elseif ischar(arg)
        % str2double never evaluates the text (str2num would); it gives
        % NaN for anything that is not a number, which fails n > 0 below
        n = str2double(arg);
    else
        n = 0;
    end
    if isscalar(n) && n > 0 && nargin >= 2
        for i = 1:n
            gw( arg2);
            if stop, break; end
        end
    else
        % num2str keeps a numeric arg readable in the message
        error( ' invalid command %s', num2str(arg));
    end
end

% Assigned last, and only on request, so that 'save' never stores it.
if nargout > 0
    out = [];
end
%%%%%%%%