function out = dig( arg, arg2)
% dig: dig for the big reward
%
%   Reinforcement-learning demo: an agent on a ring of NS states learns
%   state values V and action values Q by TD(0) with softmax action
%   selection, and can adapt its discount factor GAMMA between epochs.
%   All model and agent state lives in global variables so that the
%   companion plotting/GUI functions (digfig, digin) can share it.
%
%   Commands (ARG):
%     dig                  defaults (ns=4, rew2=6, gamma=0.9), then dig('new')
%     dig('new' [, ns])    new world, new agent, new epoch
%     dig('world')         build the reward table R (ns x na)
%     dig('agent')         reset V, Q and the learning parameters
%     dig('init')          allocate the per-epoch trace arrays (tmax = 50)
%     dig('run')           run one epoch of up to tmax steps
%     dig('try')           init + run + plot + adapt gamma ('erhv')
%     dig('stop')          set the stop flag polled by 'run' and by repeats
%     dig('erhv')          adapt gamma from mean reward and value entropy
%     dig('revalue')       pick the gamma maximizing entropy of recomputed values
%     dig('save' [, file]) save the workspace (default 'dig.mat')
%     dig('load' [, file]) list and load a saved file (default 'dig.mat')
%     dig(n, cmd)          call dig(cmd) n times; n may be a number or a
%                          numeric string; stops early when stop is set
%
%   OUT is not computed by any command; it is returned as [] so that a
%   call with an output argument does not fail.

global ns na rew1 rew2
global R V Q nt averew entval gamman
global alpha beta gamma delta lambda
global state action value reward t tmax stop vals

out = [];

if( nargin < 1)   % default
  ns = 4;
  rew2 = 6;       % bonus
  gamma = 0.9;    % discount factor
  dig( 'new');
  return;
end

switch( arg)
  %
  % initialization
  %
  case 'new'      % Setup everything
    % states
    if nargin>=2, ns=arg2, end
    dig( 'world');
    dig( 'agent');
    dig( 'init');

  case 'world'    % A new world
    % actions
    na = 2;       % R, L
    % reward field: action 1 pays +rew1 except -rew2 from state 1,
    % action 2 costs -rew1 except +rew2 from state ns
    rew1 = 1;
    % rew2 = 6;
    R = zeros( ns, na);
    R(:,1) = +rew1;
    R(1,1) = -rew2;
    R(:,2) = -rew1;
    R(ns,2) = +rew2;
    digfig( 'init');

  case 'agent'    % A new agent
    nt = 0;       % number of trials
    averew = 0;
    entval = 0;
    gamman = gamma;
    V = zeros(ns,1);
    Q = zeros(ns,na);
    alpha = 0.5;  % learning rate
    beta = 2;     % softmax action selection (inverse temperature)
    % gamma = 0.9;   % discount factor
    % lambda = 0.5;  % eligibility trace
    digfig( 'value');

  case 'init'     % A new epoch
    tmax = 50;
    state = zeros(tmax,1);
    action = zeros(tmax,1);
    reward = zeros(tmax,1);
    value = zeros(tmax,1);
    delta = zeros(tmax,1);
    % Elig = zeros(ns,1);

  %
  % run
  %
  case 'run'
    nt = nt+1;    % trial count
    stop = 0;     % for gui
    state(1) = ceil( rand*ns);    % initial state
    value(1) = V(state(1));
    for t=1:tmax
      if stop==1, break; end      % for gui
      % select an action: softmax over Q(state,:). Subtracting the row
      % maximum leaves the probabilities unchanged but avoids exp overflow.
      q = Q(state(t),:);
      pr = exp( beta*(q - max(q)));
      prob = pr./(sum(pr));       % selection probability
      act = find( cumsum(prob) > rand(1), 1);
      if isempty(act)             % cumsum may round to just below rand
        act = numel(prob);
      end
      action(t) = act;            % index of selected action
      % get reward r(t)... may also be called r(t+1)
      reward(t) = R(state(t),action(t));
      % next state: action 1 -> -1, action 2 -> +1, wrapping around the ring
      st = state(t) + action(t)*2 - 3;
      state(t+1) = mod( st-1, ns) + 1;    % 1..ns
      % value of new state
      value(t+1) = V(state(t+1));
      % TD error; the reward is scaled by (1-gamma) so values stay in
      % reward units (normalized discounted return)
      delta(t) = (1-gamma)*reward(t) + gamma*value(t+1) - value(t);
      % update values: TD zero (Q uses the same state-value TD error)
      V(state(t)) = V(state(t)) + alpha*delta(t);
      Q(state(t),action(t)) = Q(state(t),action(t)) + alpha*delta(t);
      % V = V + alpha*delta(t)*Elig;
      % update eligibility trace
      % Elig = gamma*lambda*Elig;
      % Elig(st) = 1;   % replacing
      % Elig(st) = Elig(st) + (1-gamma*lambda);
      %disp([ state(t), action(t), value(t), reward(t), delta(t)]);
    end
    digfig( 'wave', t);

  case 'try'
    dig( 'init');
    dig( 'run');
    digfig( 'value');
    dig( 'erhv');     % adapt gamma
    %dig( 'revalue');

  case 'stop'
    stop = 1

  %
  % adapt gamma
  %
  case 'erhv'     % feedback by E[r] and H[V]
    averew(nt) = (mean(reward(1:t))+rew2)/(2*rew2);   % mean reward: 0..1
    centers = -1:0.1:1;           % histogram bin centres for V/rew2
    nb = numel(centers);
    % normalise by the number of samples actually collected (t), not tmax,
    % so an epoch stopped early still yields counts with mean 1 per bin
    pd = hist( value(1:t)/rew2, centers)*nb/t;
    pd = pd + (pd==0);            % replace 0 by 1 to avoid log(0)
    entval(nt) = -pd*log(pd)'/nb; % normalized entropy of value: -inf..0
    % adjust the horizon tau = 1/(1-gamma) by at most +-1 per trial
    tau = 1/(1-gamma);
    a = 1; b = 0.5; c = -2;
    tau = tau + max( -1, min( 1, a*(0.75-averew(nt))+b*(entval(nt)-c)));
    tau = max( 1, tau);
    gamma = 1 - 1/tau;
    gamman(nt) = gamma;
    disp( [ averew(nt), entval(nt), gamman(nt)]);
    digin( 'grefresh');

  case 'revalue'  % calculate back V with different gammas
    dtau = 0.25;
    tau = 1/(1-gamma);
    taus = max( 1, [tau-dtau, tau, tau+dtau]);
    gams = 1 - 1./taus
    vals = zeros(tmax+1,3);
    vals(tmax+1,:) = value(tmax+1);
    % local index k: the global t must keep marking the end of the epoch
    for k=tmax:-1:1
      % delta(t) = (1-gamma)*reward(t) + gamma*value(t+1) - value(t);
      vals(k,:) = (1-gams)*reward(k) + gams.*vals(k+1,:);
    end
    digfig( 'reval');
    pds = hist( vals)/tmax
    pds = pds + (pds==0);         % replace 0 by 1 to avoid log(0)
    ents = -sum( pds.*log(pds));
    [maxe,maxi] = max(ents);      % keep the gamma with the largest entropy
    gamma = 1 - 1/taus(maxi)
    digin( 'grefresh');

  %
  % Miscellaneous
  %
  case 'save'
    if nargin < 2, arg2 = 'dig.mat', end
    save( arg2);

  case 'load'
    if nargin < 2, arg2 = 'dig.mat'; end
    whos( '-file', arg2);
    load( arg2);

  otherwise       % repeat: dig(n, cmd) calls dig(cmd) n times
    if isnumeric(arg)
      n = arg;                    % repeat arg times
    elseif ischar(arg)
      n = str2double(arg);        % unlike str2num, never evaluates the text
    else
      n = 0;
    end
    if ~isempty(n) && all(n(:) > 0) && nargin>=2
      for k = 1:n
        dig( arg2);
        if stop, break; end
      end
    else
      if ischar(arg)
        name = arg;
      elseif isnumeric(arg)
        name = mat2str(arg);
      else
        name = class(arg);
      end
      error( [ ' invalid command ', name]);
    end
end
%%%%%%%%